Skip to content

Commit af24474

Browse files
committed
plotting.py : get legend first before adding other labels
1 parent 5a2ca4c commit af24474

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

bindings/python/aligator/utils/plotting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def plot_convergence(
2121
dual_infeas.append(res.dual_infeas)
2222
plot_pd_errs(ax, prim_infeas, dual_infeas)
2323
ax.grid(axis="y", which="major")
24-
labels = [
25-
"$\\epsilon_\\mathrm{tol}$",
24+
handles, labels = ax.get_legend_handles_labels()
25+
labels += [
2626
"Prim. err $p$",
2727
"Dual err $d$",
2828
]
@@ -42,7 +42,7 @@ def plot_convergence(
4242
ax.vlines(al_change_idx, *ax.get_ylim(), colors="gray", lw=4.0, alpha=0.5)
4343

4444
ax.legend(labels=labels, **legend_kwargs)
45-
return
45+
return labels
4646

4747

4848
def plot_se2_pose(

examples/solo_jump.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test():
6565
test()
6666

6767

68-
dt = 10e-3 # 20 ms
68+
dt = 5e-3 # 20 ms
6969
tf = 1.2 # in seconds
7070
nsteps = int(tf / dt)
7171
print("Num steps: {:d}".format(nsteps))
@@ -79,7 +79,7 @@ def test():
7979
mask = (switch_t0 <= times) & (times < switch_t1)
8080

8181
q1 = q0.copy()
82-
q1[3:7] = pin.exp3_quat(np.array([0.0, 0.0, np.pi / 3]))
82+
# q1[3:7] = pin.exp3_quat(np.array([0.0, 0.0, np.pi / 3]))
8383
v0 = np.zeros(nv)
8484
x0_ref = np.concatenate((q0, v0))
8585
w_x = np.ones(space.ndx) * 1e-2
@@ -157,7 +157,7 @@ def create_land_cost(costs, w):
157157

158158
problem = aligator.TrajOptProblem(x0_ref, stages, term_cost)
159159
mu_init = 1e-5
160-
tol = 1e-4
160+
tol = 1e-5
161161
solver = aligator.SolverProxDDP(tol, mu_init, verbose=aligator.VERBOSE, max_iters=300)
162162
solver.rollout_type = aligator.ROLLOUT_LINEAR
163163
solver.setNumThreads(args.num_threads)
@@ -201,6 +201,7 @@ def make_plots(res: aligator.Results):
201201

202202
fig4 = plt.figure()
203203
ax = fig4.add_subplot(111)
204+
ax.hlines(tol, 0, res.num_iters, lw=2.2, alpha=0.8, colors="k")
204205
plot_convergence(cb_, ax, res, show_al_iters=True)
205206
fig4.tight_layout()
206207

0 commit comments

Comments
 (0)