Skip to content

Commit aa6ec88

Browse files
committed
improved results plotting
1 parent 5459406 commit aa6ec88

File tree

7 files changed

+70
-45
lines changed

7 files changed

+70
-45
lines changed

proj/environment/environment.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,23 @@ def plan(self, curr_x, g_traj, itern):
108108
end = traj_length
109109

110110
# Make sure planned trajectory has the correct length
111+
if start == end:
112+
return None
113+
111114
if (end - start) != pred_len:
112115
planned = g_traj[start:end]
113116
len_diff = len(planned) - pred_len
114117

115118
if len_diff <= 0:
116-
planned = np.pad(
117-
planned, ((0, abs(len_diff)), (0, 0)), mode="edge"
118-
)
119+
try:
120+
planned = np.pad(
121+
planned, ((0, abs(len_diff)), (0, 0)), mode="edge"
122+
)
123+
except ValueError:
124+
raise ValueError(
125+
f"Padding went wrong for planned with shape {planned.shape} and len_diff {len_diff}"
126+
)
127+
119128
else:
120129
raise ValueError("Something went wrong")
121130
else:

proj/environment/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _start_logging(self):
6060
filename = str(self.datafolder / f"{self.exp_name}.log")
6161
logger.add(filename)
6262
logger.info(make_header(f"Starting simulation at {timestamp()}"))
63-
logger.info("Saving data at: {self.datafolder}")
63+
logger.info(f"Saving data at: {self.datafolder}")
6464

6565
def _log_conf(self):
6666
# log config.py

proj/environment/trajectories.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ def compute_trajectory_stats(
5151
mute=False,
5252
):
5353

54-
# ! shorten traj
55-
trajectory = trajectory[: planning_params["prediction_length"] + 1, :]
56-
5754
# Compute stats
5855
n_points = len(trajectory)
5956
distance_travelled = np.sum(
@@ -140,11 +137,11 @@ def compute_trajectory_stats(
140137
print(
141138
f"[bold red]Lookahead of {lookahead} is {perc_lookahead} of the # of waypoints, that might be too low. Values closer to 5% are advised."
142139
)
143-
# if distance_travelled < min_dist_travelled * params["px_to_cm"]:
144-
# logger.warning(
145-
# "Distance travelled below minimal requirement, erroring"
146-
# )
147-
# return None, None
140+
if distance_travelled < min_dist_travelled * params["px_to_cm"]:
141+
logger.warning(
142+
"Distance travelled below minimal requirement, erroring"
143+
)
144+
return None, None
148145

149146
return trajectory, duration, metadata
150147

proj/plotting/results.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,38 @@
77
from fcutils.plotting.plot_elements import plot_line_outlined
88
from fcutils.maths.utils import derivative
99

10-
from proj.utils.misc import load_results_from_folder
10+
from proj.utils.misc import load_results_from_folder, duration_from_history
1111
from proj.animation import variables_colors as colors
1212

1313

1414
def _make_figure():
15-
f = plt.figure(figsize=(20, 8))
15+
f = plt.figure(figsize=(20, 12))
1616

17-
gs = f.add_gridspec(2, 8)
17+
gs = f.add_gridspec(3, 6)
1818

19-
xy_ax = f.add_subplot(gs[:, :2])
19+
xy_ax = f.add_subplot(gs[:2, :2])
2020
xy_ax.axis("equal")
21-
xy_ax.axis("off")
21+
# xy_ax.axis("off")
2222

2323
tau_ax = f.add_subplot(gs[0, 2:4])
24-
sax = f.add_subplot(gs[1, 2:4])
25-
accel_ax = f.add_subplot(gs[0, 4:6])
26-
cost_ax = f.add_subplot(gs[1, 4:6])
24+
sax = f.add_subplot(gs[2, :2]) # speed trajectory
25+
# accel_ax = f.add_subplot(gs[0, 4:6])
26+
# cost_ax = f.add_subplot(gs[1, 4:6])
2727

28-
tau_int_ax = f.add_subplot(gs[0, 6:])
29-
acc_int_ax = f.add_subplot(gs[1, 6:])
28+
tau_int_ax = f.add_subplot(gs[0, 4:6])
29+
omega_ax = f.add_subplot(gs[1, 4:6])
30+
speed_ax = f.add_subplot(gs[1, 2:4])
3031

31-
return f, xy_ax, tau_ax, sax, accel_ax, cost_ax, tau_int_ax, acc_int_ax
32+
return f, xy_ax, tau_ax, sax, tau_int_ax, omega_ax, speed_ax
3233

3334

34-
def _plot_xy(history, trajectory, plot_every, ax=None):
35+
def _plot_xy(history, trajectory, plot_every, duration, ax=None):
3536
# plot trajectory
3637
plot_line_outlined(
3738
ax,
3839
trajectory[:, 0],
3940
trajectory[:, 1],
40-
lw=1.5,
41+
lw=2.5,
4142
color=colors["trajectory"],
4243
outline=0.5,
4344
outline_color="white",
@@ -63,6 +64,9 @@ def _plot_xy(history, trajectory, plot_every, ax=None):
6364
outline_color=[0.2, 0.2, 0.2],
6465
)
6566

67+
# Set ax properties
68+
ax.set(xlabel="cm", ylabel="cm", title=f"Duration: {duration}s")
69+
6670

6771
def _plot_control(history, ax=None):
6872
R, L = history["tau_r"], history["tau_l"]
@@ -85,6 +89,7 @@ def _plot_control(history, ax=None):
8589
solid_capstyle="round",
8690
)
8791
ax.legend()
92+
ax.set(xlabel="# frames", ylabel="Force", title="Control history")
8893

8994

9095
def _plot_v(history, trajectory, plot_every, ax=None):
@@ -108,6 +113,12 @@ def _plot_v(history, trajectory, plot_every, ax=None):
108113
history["trajectory_idx"], v, color=colors["v"], lw=3, zorder=100,
109114
)
110115

116+
ax.set(
117+
xlabel="Trajectory idx",
118+
ylabel="Speed (cm/s)",
119+
title="Speed trajectory",
120+
)
121+
111122

112123
def _plot_accel(history, ax=None):
113124
v, omega = history["v"], history["omega"]
@@ -134,7 +145,7 @@ def _plot_cost(cost_history, ax=None):
134145
ax.legend()
135146

136147

137-
def _plot_integrals(history, dt, tax=None, aax=None):
148+
def _plot_integrals(history, dt, tax=None, oax=None, sax=None):
138149
R, L = history["nudot_right"], history["nudot_left"]
139150

140151
plot_line_outlined(
@@ -153,54 +164,56 @@ def _plot_integrals(history, dt, tax=None, aax=None):
153164
lw=2,
154165
solid_capstyle="round",
155166
)
167+
tax.set(title="Wheels accelerations", xlabel="# Frames", ylabel="accel")
156168
tax.legend()
157169

158170
# plot v and omega
159171
v, omega = history["v"], history["omega"]
160172

161173
plot_line_outlined(
162-
aax,
174+
sax,
163175
v,
164176
color=desaturate_color(colors["v"]),
165177
label="$v$",
166178
lw=2,
167179
solid_capstyle="round",
168180
)
181+
sax.legend()
182+
sax.set(title="Running speed", xlabel="# frames", ylabel="$v$")
183+
169184
plot_line_outlined(
170-
aax,
185+
oax,
171186
omega,
172187
color=desaturate_color(colors["omega"]),
173188
label="$\\omega$",
174189
lw=2,
175190
solid_capstyle="round",
176191
)
177-
aax.legend()
192+
oax.legend()
193+
oax.set(title="Angular velocity", xlabel="# frames", ylabel="$\\omega$")
178194

179195

180196
def plot_results(results_folder, plot_every=20, save_path=None):
181197
config, trajectory, history, cost_history = load_results_from_folder(
182198
results_folder
183199
)
200+
duration = duration_from_history(history, config)
184201

185-
(
186-
f,
187-
xy_ax,
188-
tau_ax,
189-
sax,
190-
accel_ax,
191-
cost_ax,
192-
tau_int_ax,
193-
acc_int_ax,
194-
) = _make_figure()
202+
f, xy_ax, tau_ax, sax, tau_int_ax, omega_ax, speed_ax = _make_figure()
195203

196-
_plot_xy(history, trajectory, plot_every, ax=xy_ax)
204+
_plot_xy(history, trajectory, plot_every, duration, ax=xy_ax)
197205
_plot_control(history, ax=tau_ax)
198-
_plot_v(history, trajectory, plot_every, ax=sax)
199-
_plot_accel(history, ax=accel_ax)
200-
_plot_cost(cost_history, ax=cost_ax)
201-
_plot_integrals(history, config["dt"], tax=tau_int_ax, aax=acc_int_ax)
206+
_plot_v(
207+
history, trajectory, plot_every, ax=sax
208+
) # plot v against the trajectory
209+
# _plot_accel(history, ax=accel_ax)
210+
# _plot_cost(cost_history, ax=cost_ax)
211+
_plot_integrals(
212+
history, config["dt"], tax=tau_int_ax, oax=omega_ax, sax=speed_ax
213+
)
202214

203215
clean_axes(f)
216+
f.tight_layout()
204217

205218
if save_path is not None:
206219
save_figure(f, str(save_path))

proj/run/runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def run_experiment(
122122
try:
123123
environment.conclude()
124124
except:
125+
logger.info("Failed to run environment.conclude()")
125126
environment.failed()
126127
return
127128

proj/utils/misc.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def traj_to_polar(traj):
6767

6868

6969
# ----------------------------------- Misc ----------------------------------- #
70+
def duration_from_history(history, config):
71+
nframes = len(history)
72+
return round(nframes * config["dt"], 3)
73+
74+
7075
def timeit(method):
7176
def timed(*args, **kw):
7277
ts = time.time()

run_allocentric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# plot_trajectory(env.reset())
2626

2727
# %%
28-
run_experiment(env, control, model, n_secs=0.5)
28+
run_experiment(env, control, model, n_secs=0.05)
2929

3030

3131
# %%

0 commit comments

Comments
 (0)