Skip to content

Commit 2587d39

Browse files
committed
small edits to improve figure 3
1 parent 4db292f commit 2587d39

File tree

2 files changed

+77
-21
lines changed

2 files changed

+77
-21
lines changed

graphs.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,52 @@
33

44

55
def customize_plot_fossolo(experiment):
6-
experiment.plot_constraints = True
6+
experiment.plot_constraints = False
77
experiment.plot_y_legend_label = ["DeePC"]
88
experiment.plot_u_legend_label = ["DeePC"]
99
for s in experiment.compare_signals:
1010
s["plot"] = True
1111
experiment.plot_legend_cols = 5
1212
fig = experiment.plot(experiment.wds)
1313
fig, u, y = experiment.plot_comparison_signals(fig, signals=experiment.compare_signals)
14+
15+
u, y = experiment.run_pid(kp=0.01, kd=0.1, ki=1, T=experiment.experiment_horizon+experiment.n_train)
16+
17+
_u = u[-experiment.experiment_horizon:, :]
18+
_y = y[-experiment.experiment_horizon:, :]
19+
20+
mae, me = experiment.get_error(y=_y)
21+
ref = np.ones(_y.shape) * experiment.y_ref
22+
cost = np.linalg.norm(_y - ref, 'fro') ** 2 + experiment.input_loss * np.linalg.norm(_u, 'fro') ** 2
23+
v_count, v_rate = experiment.get_violations(y=_y)
24+
print(f"PID --> Cost: {cost:.3f} | MAE: {mae:.3f} | ME: {me:.3f}")
25+
print(f"PID --> {experiment.y_lb}-{experiment.y_ub} Violations: {v_count:.0f}"
26+
f" | Violations Rate: {v_rate:.3f}")
27+
v_count, v_rate = experiment.get_violations(y=_y, lb=20, ub=40)
28+
print(f"PID --> {20}-{40} Violations: {v_count:.0f} | Violations Rate: {v_rate:.3f}")
29+
1430
axes = fig.axes
15-
axes[0].set_ylim(6, 44)
16-
axes[1].set_ylim(26, 65)
31+
axes[0].fill_between(range(experiment.n_train, experiment.experiment_horizon+experiment.n_train+2),
32+
experiment.y_lb, experiment.y_ub,
33+
alpha=0.2, color='grey', label="Constraints")
34+
axes[0].plot(y, '#35AC78', label="PID", zorder=3, linewidth=1.2)
35+
axes[1].step(u, '#35AC78', label="PID", zorder=3, where='post', linewidth=1.2)
36+
37+
handles, labels = axes[0].get_legend_handles_labels()
38+
order = [1, 4, 0, 2, 3, 5] # reorder the legend items - after adding PID
39+
axes[0].legend([handles[i] for i in order], [labels[i] for i in order], ncol=6, columnspacing=0.75, handletextpad=0.2)
40+
41+
# handles, labels = axes[1].get_legend_handles_labels()
42+
# order = [0, 2, 3, 5] # reorder the legend items - after adding PID
43+
# axes[1].legend([handles[i] for i in order], [labels[i] for i in order], ncol=4, columnspacing=0.75, handletextpad=0.15)
44+
axes[1].legend(ncol=4, columnspacing=0.75, handletextpad=0.15)
45+
46+
axes[0].set_ylim(8, 42)
47+
axes[1].set_ylim(30, 66)
1748
axes[2].set_ylim(0.5, 1.6)
1849
axes[-1].set_xlim(experiment.n_train, experiment.n_train + experiment.experiment_horizon + 4)
1950
fig.set_size_inches(8, 6)
20-
plt.subplots_adjust(left=0.125, right=0.9, bottom=0.1, top=0.9)
21-
22-
mae, me = experiment.get_error(y=y)
23-
ref = np.ones(y.shape) * experiment.y_ref
24-
cost = np.linalg.norm(y - ref, 'fro') ** 2 + experiment.input_loss * np.linalg.norm(u, 'fro') ** 2
25-
v_count, v_rate = experiment.get_violations(y=y)
26-
print(f"{s['name']} --> Cost: {cost:.3f} | MAE: {mae:.3f} | ME: {me:.3f}")
27-
print(f"{s['name']} --> {experiment.y_lb}-{experiment.y_ub} Violations: {v_count:.0f}"
28-
f" | Violations Rate: {v_rate:.3f}")
29-
v_count, v_rate = experiment.get_violations(y=y, lb=20, ub=40)
30-
print(f"{s['name']} --> {20}-{40} Violations: {v_count:.0f} | Violations Rate: {v_rate:.3f}")
51+
plt.subplots_adjust(left=0.125, right=0.9, bottom=0.1, top=0.9, hspace=0.1)
3152

3253

3354
def customize_plot_pescara(experiment):

main.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
import graphs
1313
import system
1414
import utils
15+
from pid import PID
1516

16-
COLORS = ["#0077B8", "#DF5353", "#fdc85e", "#b7b3aa"]
17+
COLORS = ["#0686cc", "#DF5353", "#fdc85e", "#b7b3aa"]#0077B8
1718
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=COLORS)
1819

1920

@@ -105,7 +106,7 @@ def run_experiment(self):
105106

106107
def get_error(self, y, n=None):
107108
if not isinstance(self.y_ref, np.ndarray):
108-
y_ref = np.array([self.y_ref])
109+
y_ref = np.tile(self.y_ref, (len(y), 1))
109110
else:
110111
y_ref = self.y_ref
111112

@@ -139,16 +140,16 @@ def plot(self, sys, fig=None, moving_avg_size=0):
139140
k = 0
140141

141142
if fig is None:
142-
inputs_heights = [1 for _ in range(n_inputs + k)]
143+
inputs_heights = [1 for _ in range(n_inputs)] + [0.5 for _ in range(k)]
143144
fig, axes = plt.subplots(nrows=n_inputs + 1 + k, sharex=True, height_ratios=[1] + inputs_heights)
144145
else:
145146
axes = fig.axes
146147

147148
t = len(sys.target_values[1:, 0])
148149
for i, element in enumerate(sys.target_nodes):
149-
axes[0].plot(sys.target_values[1:, i], label=self.plot_y_legend_label[i], zorder=4)
150+
axes[0].plot(sys.target_values[1:, i], label=self.plot_y_legend_label[i], zorder=10)
150151

151-
axes[0].hlines(y=self.y_ref, xmin=0, xmax=t, color='k', zorder=5, label="$y_{ref}$")
152+
axes[0].hlines(y=self.y_ref, xmin=0, xmax=t, color='k', zorder=1, linestyle='--', label="$y_{ref}$")
152153
axes[0].axvspan(0, self.n_train, facecolor='grey', alpha=0.3, zorder=0)
153154
axes[0].grid(True)
154155
axes[0].set_ylabel(utils.split_label(self.output_y_label, self.plot_y_labels_max_len))
@@ -223,9 +224,20 @@ def plot_comparison_signals(self, fig, signals: list):
223224
y = sys.target_values[-self.experiment_horizon:, :]
224225

225226
t = len(sys.target_values[1:, 0])
227+
228+
mae, me = self.get_error(y=y)
229+
ref = np.ones(y.shape) * self.y_ref
230+
cost = np.linalg.norm(y - ref, 'fro') ** 2 + self.input_loss * np.linalg.norm(u, 'fro') ** 2
231+
v_count, v_rate = self.get_violations(y=y)
232+
print(f"{s['name']} --> Cost: {cost:.3f} | MAE: {mae:.3f} | ME: {me:.3f}")
233+
print(f"{s['name']} --> {self.y_lb}-{self.y_ub} Violations: {v_count:.0f}"
234+
f" | Violations Rate: {v_rate:.3f}")
235+
v_count, v_rate = self.get_violations(y=y, lb=20, ub=40)
236+
print(f"{s['name']} --> {20}-{40} Violations: {v_count:.0f} | Violations Rate: {v_rate:.3f}")
237+
226238
if "plot" in s and s["plot"]:
227239
for i, element in enumerate(sys.target_nodes):
228-
axes[0].plot(sys.target_values[1:, i], label=s["name"], zorder=2)
240+
axes[0].plot(sys.target_values[1:, i], label=s["name"], zorder=2, linewidth=1.2)
229241
if self.plot_legend_cols > 0:
230242
# axes[0].legend(ncols=self.plot_legend_cols, fontsize=10)
231243
handles, labels = axes[0].get_legend_handles_labels()
@@ -234,7 +246,7 @@ def plot_comparison_signals(self, fig, signals: list):
234246
fontsize=9)
235247

236248
for i, element in enumerate(sys.control_nodes + sys.control_links):
237-
axes[1].step(range(t), sys.implemented[1:, i], label=s["name"], zorder=2, where='post')
249+
axes[1].step(range(t), sys.implemented[1:, i], label=s["name"], zorder=2, where='post', linewidth=1.2)
238250
if self.plot_legend_cols > 0:
239251
# axes[1].legend(ncols=self.plot_legend_cols, fontsize=10)
240252
handles, labels = axes[1].get_legend_handles_labels()
@@ -244,6 +256,29 @@ def plot_comparison_signals(self, fig, signals: list):
244256

245257
return fig, u, y
246258

259+
def run_pid(self, kp, ki, kd, T):
260+
sys = system.WDSControl(inp_path=self.inp_path,
261+
control_links=self.control_links,
262+
control_nodes=self.control_nodes,
263+
target_nodes=self.target_nodes,
264+
target_param=self.target_param
265+
)
266+
267+
pid = PID(kp=kp, ki=ki, kd=kd, set_point=self.y_ref)
268+
current_value = 0.5 * (self.u_lb + self.u_ub) # Initial value
269+
for _ in range(T):
270+
u_optimal = pid.compute(current_value=current_value)
271+
sys.apply_input(u=np.array([u_optimal]), noise_std=self.noise_std)
272+
current_value = sys.get_last_n_samples(1).y[0]
273+
274+
y = sys.target_values
275+
u = sys.implemented
276+
ref = np.ones(y.shape) * self.y_ref
277+
self.cost = np.linalg.norm(y - ref, 'fro') ** 2 + self.input_loss * np.linalg.norm(u, 'fro') ** 2
278+
self.mae, self.me = self.get_error(y=y)
279+
self.v_count, self.v_rate = self.get_violations(y=y)
280+
return u, y
281+
247282

248283
def run_comparable_signal(sys, ref_input_signal, noise_std):
249284
ref_sys = system.WDSControl(inp_path=sys.inp_path,

0 commit comments

Comments
 (0)