Skip to content

Commit 48b8b5f

Browse files
authored
Merge pull request #292 from thowell/direct_python_pin
Add qpos pinning to direct optimizer (python)
2 parents bd6eda6 + fcb43a4 commit 48b8b5f

File tree

7 files changed

+494
-760
lines changed

7 files changed

+494
-760
lines changed
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
# Copyright 2023 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import matplotlib.pyplot as plt
16+
import mediapy as media
17+
import mujoco
18+
from mujoco_mpc import direct as direct_lib
19+
import numpy as np
20+
21+
22+
# %%
23+
# 2D Particle Model
24+
xml = """
25+
<mujoco model="Particle">
26+
<visual>
27+
<headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/>
28+
<map znear=".01"/>
29+
<quality shadowsize="2048"/>
30+
<global elevation="-15"/>
31+
</visual>
32+
33+
<asset>
34+
<texture name="blue_grid" type="2d" builtin="checker" rgb1=".02 .14 .44" rgb2=".27 .55 1" width="300" height="300" mark="edge" markrgb="1 1 1"/>
35+
<material name="blue_grid" texture="blue_grid" texrepeat="1 1" texuniform="true" reflectance=".2"/>
36+
<texture name="skybox" type="skybox" builtin="gradient" rgb1=".66 .79 1" rgb2=".9 .91 .93" width="800" height="800"/>
37+
<material name="self" rgba=".7 .5 .3 1"/>
38+
<material name="decoration" rgba=".2 .6 .3 1"/>
39+
</asset>
40+
41+
<option timestep="0.01"></option>
42+
43+
<default>
44+
<joint type="hinge" axis="0 0 1" limited="true" range="-.29 .29" damping="1"/>
45+
<motor gear=".1" ctrlrange="-1 1" ctrllimited="true"/>
46+
</default>
47+
48+
<worldbody>
49+
<light name="light" pos="0 0 1"/>
50+
<camera name="fixed" pos="0 0 .75" quat="1 0 0 0"/>
51+
<geom name="ground" type="plane" pos="0 0 0" size=".3 .3 .1" material="blue_grid"/>
52+
<geom name="wall_x" type="plane" pos="-.3 0 .02" zaxis="1 0 0" size=".02 .3 .02" material="decoration"/>
53+
<geom name="wall_y" type="plane" pos="0 -.3 .02" zaxis="0 1 0" size=".3 .02 .02" material="decoration"/>
54+
<geom name="wall_neg_x" type="plane" pos=".3 0 .02" zaxis="-1 0 0" size=".02 .3 .02" material="decoration"/>
55+
<geom name="wall_neg_y" type="plane" pos="0 .3 .02" zaxis="0 -1 0" size=".3 .02 .02" material="decoration"/>
56+
57+
<body name="pointmass" pos="0 0 .01">
58+
<camera name="cam0" pos="0 -0.3 0.3" xyaxes="1 0 0 0 0.7 0.7"/>
59+
<joint name="root_x" type="slide" pos="0 0 0" axis="1 0 0" />
60+
<joint name="root_y" type="slide" pos="0 0 0" axis="0 1 0" />
61+
<geom name="pointmass" type="sphere" size=".01" material="self" mass=".3"/>
62+
<site name="tip" pos="0 0 0" size="0.01"/>
63+
</body>
64+
</worldbody>
65+
66+
<actuator>
67+
<motor name="x_motor" joint="root_x" gear="1" ctrllimited="true" ctrlrange="-1 1"/>
68+
<motor name="y_motor" joint="root_y" gear="1" ctrllimited="true" ctrlrange="-1 1"/>
69+
</actuator>
70+
71+
<sensor>
72+
<jointpos name="x" joint="root_x" />
73+
<jointpos name="y" joint="root_y" />
74+
</sensor>
75+
</mujoco>
76+
"""
77+
78+
model = mujoco.MjModel.from_xml_string(xml)
79+
data = mujoco.MjData(model)
80+
renderer = mujoco.Renderer(model)
81+
# %%
82+
# initialization
83+
# number of interpolation steps
T = 100

# start, two intermediate waypoints, and goal positions (x, y)
q0 = np.array([-0.25, -0.25])
qM = np.array([-0.25, 0.25])
qN = np.array([0.25, -0.25])
qT = np.array([0.25, 0.25])

# compute linear interpolation
# per-step increment; it does not depend on the step index, so compute it once
slope = (qT - q0) / T

# column t holds q0 + t * slope for t = 0, ..., T - 1, so the last column is
# one increment short of qT (same values as filling the columns in a loop)
qinterp = q0[:, None] + np.arange(T) * slope[:, None]
97+
98+
# time
99+
time = [t * model.opt.timestep for t in range(T)]
100+
# %%
101+
# plot position
102+
fig = plt.figure()
103+
104+
# particle position
105+
plt.plot(qinterp[0, :], qinterp[1, :], label="interpolation", color="black")
106+
plt.plot(q0[0], q0[1], color="magenta", label="waypoint", marker="o")
107+
plt.plot(qM[0], qM[1], color="magenta", marker="o")
108+
plt.plot(qN[0], qN[1], color="magenta", marker="o")
109+
plt.plot(qT[0], qT[1], color="magenta", marker="o")
110+
111+
plt.legend()
112+
plt.xlabel("X")
113+
plt.ylabel("Y")
114+
# %%
115+
# optimizer model
116+
model_optimizer = mujoco.MjModel.from_xml_string(xml)
117+
118+
# direct optimizer
119+
configuration_length = T + 2
120+
optimizer = direct_lib.Direct(
121+
model=model_optimizer,
122+
configuration_length=configuration_length,
123+
)
124+
# %%
125+
# set data
126+
# intermediate waypoints: configuration index -> position measurement
waypoints = {25: qM, 75: qN}

for t in range(configuration_length):
  # per-step defaults: zero configuration guess, zero sensor measurement,
  # zero force measurement
  qt = np.zeros(model.nq)
  st = np.zeros(model.nsensordata)
  ft = np.zeros(model.nv)
  tt = np.array([t * model.opt.timestep])

  # sensor mask is all ones unless the step is a free step (see below)
  mt = np.array([1, 1])

  if t in (0, 1):
    # set initial state (first two configurations)
    qt = q0
    st = q0
  elif t >= configuration_length - 2:
    # set goal (last two configurations)
    qt = qT
    st = qT
  elif t in waypoints:
    # set waypoint measurement
    # NOTE(review): the configuration guess stays at zero for waypoint steps
    # instead of following the interpolation -- confirm this is intended
    st = waypoints[t]
  else:
    # free step: initialize qpos from the interpolation and zero the mask
    qt = qinterp[:, t - 1]
    mt = np.array([0, 0])

  # set data
  optimizer.data(
      t,
      configuration=qt,
      sensor_measurement=st,
      sensor_mask=mt,
      force_measurement=ft,
      time=tt,
  )
171+
# %%
172+
# set std
173+
optimizer.noise(process=np.array([1000.0, 1000.0]), sensor=np.array([1.0, 1.0]))
174+
175+
# set settings
176+
optimizer.settings(
177+
sensor_flag=True,
178+
force_flag=True,
179+
max_smoother_iterations=1000,
180+
max_search_iterations=1000,
181+
regularization_initial=1.0e-12,
182+
gradient_tolerance=1.0e-6,
183+
search_direction_tolerance=1.0e-6,
184+
cost_tolerance=1.0e-6,
185+
first_step_position_sensors=True,
186+
last_step_position_sensors=True,
187+
last_step_velocity_sensors=True,
188+
)
189+
190+
# optimize
191+
optimizer.optimize()
192+
193+
# costs
194+
optimizer.print_cost()
195+
196+
# status
197+
optimizer.print_status()
198+
# %%
199+
# get estimated trajectories
200+
# output buffers, one column per configuration step
q_est = np.zeros((model_optimizer.nq, configuration_length))
v_est = np.zeros((model_optimizer.nv, configuration_length))
s_est = np.zeros((model_optimizer.nsensordata, configuration_length))
f_est = np.zeros((model_optimizer.nv, configuration_length))
t_est = np.zeros(configuration_length)

# optimizer data key -> buffer that receives it
estimates = {
    "configuration": q_est,
    "velocity": v_est,
    "sensor_prediction": s_est,
    "force_prediction": f_est,
}

for t in range(configuration_length):
  # query the optimizer once per step and copy each field into its column
  data_ = optimizer.data(t)
  for key, buffer in estimates.items():
    buffer[:, t] = data_[key]
  t_est[t] = data_["time"]
212+
# %%
213+
# plot position
214+
fig = plt.figure()
215+
216+
plt.plot(qinterp[0, :], qinterp[1, :], label="interpolation", color="black")
217+
plt.plot(q_est[0, :], q_est[1, :], label="direct trajopt", color="orange")
218+
plt.plot(q0[0], q0[1], color="magenta", label="waypoint", marker="o")
219+
plt.plot(qM[0], qM[1], color="magenta", marker="o")
220+
plt.plot(qN[0], qN[1], color="magenta", marker="o")
221+
plt.plot(qT[0], qT[1], color="magenta", marker="o")
222+
223+
plt.legend()
224+
plt.xlabel("X")
225+
plt.ylabel("Y")
226+
227+
# plot velocity
228+
fig = plt.figure()
229+
230+
# velocity
231+
plt.plot(t_est[1:] - model.opt.timestep, v_est[0, 1:], label="v0", color="cyan")
232+
plt.plot(
233+
t_est[1:] - model.opt.timestep, v_est[1, 1:], label="v1", color="orange"
234+
)
235+
236+
plt.legend()
237+
plt.xlabel("Time (s)")
238+
plt.ylabel("Velocity")
239+
# %%
240+
# frames optimized
241+
frames_opt = []
242+
243+
# simulate
244+
for t in range(configuration_length - 1):
  # set state from the trajectory already extracted into q_est / v_est;
  # no additional optimizer query is needed here
  data.qpos = q_est[:, t]
  data.qvel = v_est[:, t]

  # run forward computation so the scene reflects the new state
  mujoco.mj_forward(model, data)

  # render and save frames
  renderer.update_scene(data)
  frames_opt.append(renderer.render())
258+
259+
# display video
260+
# media.show_video(frames_opt, fps=1.0 / model.opt.timestep, loop=False)

0 commit comments

Comments
 (0)