Skip to content

Commit 8dec503

Browse files
committed
Attempt to fix examples/solo_bend.py
1 parent ba78dd2 commit 8dec503

File tree

1 file changed

+54
-38
lines changed

1 file changed

+54
-38
lines changed

examples/solo_bend.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,15 @@
44
import coal
55

66
from aligator import manifolds, dynamics
7-
from pinocchio.visualize import MeshcatVisualizer
8-
from pinocchio.visualize.meshcat_visualizer import COLOR_PRESETS
7+
from utils import ArgsBase
98
from utils.solo import rmodel, rdata, robot, q0, create_ground_contact_model
109

11-
COLOR_PRESETS["white"] = ([1, 1, 1], [1, 1, 1])
1210

11+
class Args(ArgsBase):
12+
pass
1313

14-
vizer = MeshcatVisualizer(
15-
rmodel,
16-
collision_model=robot.collision_model,
17-
visual_model=robot.visual_model,
18-
data=rdata,
19-
)
14+
15+
args = Args().parse_args()
2016

2117
pin.framesForwardKinematics(rmodel, rdata, q0)
2218

@@ -38,15 +34,23 @@ def define_dynamics():
3834

3935

4036
ode = define_dynamics()
41-
timestep = 0.01
37+
timestep = 0.02
4238
Tf = 4.0
4339
nsteps = int(Tf / timestep)
4440

4541
dyn_model = dynamics.IntegratorSemiImplEuler(ode, timestep)
4642

47-
u0 = np.zeros(nu)
4843
v0 = np.zeros(nv)
4944
x0 = np.concatenate([q0, v0])
45+
u0, _ = aligator.underactuatedConstrainedInverseDynamics(
46+
rmodel,
47+
rdata,
48+
q0,
49+
v0,
50+
act_matrix,
51+
ode.constraint_models,
52+
[cm.createData() for cm in ode.constraint_models],
53+
)
5054

5155

5256
def create_target(i: int):
@@ -70,62 +74,74 @@ def update_target(sphere, x):
7074

7175

7276
# Define cost functions
73-
base_weight = 2.0
74-
w_xreg = np.diag([1e-3] * nv + [1e-3] * nv)
75-
w_xreg[range(3), range(3)] = base_weight
77+
print(nv)
78+
w_xreg = [0.1] * 3 + [0.01] * (nv - 3) + [1e-4] * nv
79+
w_xreg = np.diag(w_xreg)
80+
print(w_xreg)
7681

77-
w_ureg = np.eye(nu) * 1e-3
78-
ureg_cost = aligator.QuadraticControlCost(space, u0, w_ureg * timestep)
82+
w_ureg = np.eye(nu) * 1e-4
83+
u_cost = aligator.QuadraticControlCost(space, u0, w_ureg * timestep)
7984

8085
stages = []
8186
for i in range(nsteps):
8287
x_cost = aligator.QuadraticStateCost(space, nu, X_TARGETS[i], w_xreg * timestep)
8388
rcost = aligator.CostStack(space, nu)
8489
rcost.addCost(x_cost)
85-
rcost.addCost(ureg_cost)
90+
rcost.addCost(u_cost)
8691
stm = aligator.StageModel(rcost, dyn_model)
8792
stages.append(stm)
8893

89-
w_xterm = np.diag([1e-3] * nv + [1e-3] * nv)
90-
w_xterm[range(3), range(3)] = base_weight
94+
w_xterm = 1.0 * np.eye(space.ndx)
9195
xreg_term = aligator.QuadraticStateCost(space, nu, X_TARGETS[nsteps], w_xterm)
9296
term_cost = xreg_term
9397

9498

9599
def main():
96-
vizer.initViewer(loadModel=True, open=True)
97-
vizer.display(q0)
98-
vizer.setBackgroundColor("white")
100+
if args.display:
101+
from pinocchio.visualize import MeshcatVisualizer
102+
from pinocchio.visualize.meshcat_visualizer import COLOR_PRESETS
103+
104+
COLOR_PRESETS["white"] = ([1, 1, 1], [1, 1, 1])
105+
sphere = pin.GeometryObject("target", 0, pin.SE3.Identity(), coal.Sphere(0.01))
106+
sphere.meshColor[:] = 217, 101, 38, 120
107+
sphere.meshColor /= 255.0
108+
robot.visual_model.addGeometryObject(sphere)
109+
vizer = MeshcatVisualizer(
110+
rmodel,
111+
collision_model=robot.collision_model,
112+
visual_model=robot.visual_model,
113+
data=rdata,
114+
)
115+
vizer.initViewer(loadModel=True, open=True)
116+
vizer.display(q0)
117+
vizer.setBackgroundColor("white")
99118

100-
# display target as a transparent sphere
101-
sphere = pin.GeometryObject("target", 0, pin.SE3.Identity(), coal.Sphere(0.01))
102-
sphere.meshColor[:] = 217, 101, 38, 120
103-
sphere.meshColor /= 255.0
104-
vizer.addGeometryObject(sphere)
105-
106-
xs_i = [x0] * (nsteps + 1)
107119
us_i = [u0] * nsteps
120+
xs_i = [x0] * (nsteps + 1)
108121

109122
problem = aligator.TrajOptProblem(x0, stages, term_cost)
110123

111124
mu_init = 1e-2
112-
solver = aligator.SolverProxDDP(1e-3, mu_init, verbose=aligator.VERBOSE)
113-
solver.reg_init = 1e-8
125+
tol = 1e-4
126+
solver = aligator.SolverProxDDP(tol, mu_init, verbose=aligator.VERBOSE)
127+
# solver.sa_strategy = aligator.SA_FILTER
114128
solver.setup(problem)
115129
flag = solver.run(problem, xs_i, us_i)
116130
print(flag)
117131

118132
rs = solver.results
119133
qs_opt = [x[:nq] for x in rs.xs]
120-
input("[display?]")
121134

122-
def callback(i: int):
123-
update_target(sphere, X_TARGETS[i])
135+
if args.display:
136+
input("[display?]")
137+
138+
def callback(i: int):
139+
update_target(sphere, X_TARGETS[i])
124140

125-
NR = 3
126-
for _ in range(NR):
127-
vizer.play(qs_opt, timestep, callback=callback)
128-
input()
141+
num_repeat = 3
142+
for _ in range(num_repeat):
143+
vizer.play(qs_opt, timestep, callback=callback)
144+
input()
129145

130146

131147
if __name__ == "__main__":

0 commit comments

Comments
 (0)