Commit 26091ef

add docstrings
1 parent d8dd9ae commit 26091ef

File tree

1 file changed (+116, -14 lines)


python/mujoco_mpc/demos/predictive_sampling/predictive_sampling.py

Lines changed: 116 additions & 14 deletions
@@ -22,9 +22,9 @@
 # https://arxiv.org/abs/2212.00541


-# policy class for predictive sampling
 class Policy:
-  # initialize policy
+  """Policy class for predictive sampling."""
+
   def __init__(
       self,
       naction: int,
@@ -33,6 +33,15 @@ def __init__(
       interp: str = "zero",
       limits: np.array = None,
   ):
+    """Initialize policy class.
+
+    Args:
+      naction (int): number of actions
+      horizon (float): planning horizon (seconds)
+      splinestep (float): interval length between spline points
+      interp (str, optional): type of action interpolation. Defaults to "zero".
+      limits (np.array, optional): lower and upper bounds on actions. Defaults to None.
+    """
     self._naction = naction
     self._splinestep = splinestep
     self._horizon = horizon
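Taken together with the constructor signature above, a hypothetical construction and query might look like the sketch below. The import path, the numeric values, and the `(naction, 2)` shape of `limits` are assumptions for illustration, not code from the repo.

```python
import numpy as np

from predictive_sampling import Policy  # hypothetical import path

naction = 2
policy = Policy(
    naction=naction,
    horizon=0.5,      # plan 0.5 s ahead
    splinestep=0.1,   # spline point every 0.1 s
    interp="zero",    # zero-order hold between spline points
    limits=np.array([[-1.0, 1.0], [-1.0, 1.0]]),  # assumed shape: (naction, 2)
)

# evaluate the plan at a query time (seconds)
u = policy.action(0.25)
```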
@@ -44,8 +53,16 @@ def __init__(
     self._interp = interp
     self._limits = limits

-  # find interval containing value
   def _find_interval(self, sequence: np.array, value: float) -> [int, int]:
+    """Find neighboring indices in sequence containing value.
+
+    Args:
+      sequence (np.array): array of values
+      value (float): value to find in interval
+
+    Returns:
+      [int, int]: lower and upper indices in sequence containing value
+    """
     # bisection search to get interval
     upper = bisect.bisect_right(sequence, value)
     lower = upper - 1
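The `_find_interval` docstring above describes a bisection search over a sorted sequence of spline times. Below is a minimal standalone sketch of that bracketing behavior, using only the standard `bisect` module; it mirrors the end-clamping visible in the next hunk but is not the repo's exact code.

```python
import bisect

def find_interval(sequence, value):
  """Return (lower, upper) indices in sequence bracketing value, clamped to the ends."""
  L = len(sequence)
  upper = bisect.bisect_right(sequence, value)
  lower = upper - 1
  if lower < 0:  # value lies before the first entry
    return (0, 0)
  return (max(lower, 0), min(upper, L - 1))

times = [0.0, 0.1, 0.2, 0.3]
print(find_interval(times, 0.25))  # (2, 3)
print(find_interval(times, -1.0))  # (0, 0)
print(find_interval(times, 9.0))   # (3, 3)
```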
@@ -60,8 +77,17 @@ def _find_interval(self, sequence: np.array, value: float) -> [int, int]:
       return (L - 1, L - 1)
     return (max(lower, 0), min(upper, L - 1))

-  # compute slope at value
   def _slope(self, input: np.array, output: np.array, value: float) -> np.array:
+    """Compute interpolated slope vector at value.
+
+    Args:
+      input (np.array): scalar sequence of inputs
+      output (np.array): vector sequence of outputs
+      value (float): input where to compute slope
+
+    Returns:
+      np.array: interpolated slope vector
+    """
     # bounds
     bounds = self._find_interval(input, value)

@@ -97,8 +123,15 @@ def _slope(self, input: np.array, output: np.array, value: float) -> np.array:
         input[bounds[0]] - input[bounds[0] - 1]
     )

-  # get action from policy
   def action(self, time: float) -> np.array:
+    """Return action from policy at time.
+
+    Args:
+      time (float): time value to evaluate plan for action
+
+    Returns:
+      np.array: interpolated action at time
+    """
     # find interval containing time
     bounds = self._find_interval(self._times, time)

@@ -141,8 +174,12 @@ def action(self, time: float) -> np.array:
     else:  # self._interp == "zero"
       return self.clamp(self._parameters[:, bounds[0]])

-  # resample policy plan from current time
   def resample(self, time: float):
+    """Resample plan starting from time.
+
+    Args:
+      time (float): time value to start updated plan
+    """
     # new times and parameters
     times = np.array(
         [i * self._splinestep + time for i in range(self._nspline)], dtype=float
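The `action` docstring above promises an interpolated action, and the hunk shows the zero-order-hold branch (`interp == "zero"`). As a rough, self-contained illustration of zero-order vs. linear interpolation over spline knots (function name and the end-clamping detail are assumptions, not the repo's implementation):

```python
import bisect
import numpy as np

def interpolate_plan(times, parameters, t, interp="zero"):
  """Evaluate a knot-based plan (naction x nknots) at time t."""
  t = min(max(t, times[0]), times[-1])  # clamp query time to the knot range
  upper = min(bisect.bisect_right(times, t), len(times) - 1)
  lower = max(upper - 1, 0)
  if interp == "linear" and times[upper] > times[lower]:
    frac = (t - times[lower]) / (times[upper] - times[lower])
    return (1.0 - frac) * parameters[:, lower] + frac * parameters[:, upper]
  return parameters[:, lower]  # "zero": hold the value at the lower knot

times = [0.0, 0.1, 0.2, 0.3]
params = np.array([[0.0, 1.0, 2.0, 3.0]])  # one action, four knots
print(interpolate_plan(times, params, 0.15))            # [1.]  (zero-order hold)
print(interpolate_plan(times, params, 0.15, "linear"))  # [1.5]
```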
@@ -153,16 +190,27 @@ def resample(self, time: float):
     self._times = times
     self._parameters = parameters

-  # add zero-mean Gaussian noise to policy parameters
   def add_noise(self, scale: float):
+    """Add zero-mean Gaussian noise to plan.
+
+    Args:
+      scale (float): standard deviation of zero-mean Gaussian noise
+    """
     # clamp within limits
     self._parameters = self.clamp(
         self._parameters
         + np.random.normal(scale=scale, size=(self._naction, self._nspline))
     )

-  # return a copy of the policy with noisy parameters
   def noisy_copy(self, scale: float) -> Policy:
+    """Return a copy of plan with added noise.
+
+    Args:
+      scale (float): standard deviation of zero-mean Gaussian noise
+
+    Returns:
+      Policy: copy of object with noisy plan
+    """
     # create new policy object
     policy = Policy(self._naction, self._horizon, self._splinestep)

@@ -174,8 +222,15 @@ def noisy_copy(self, scale: float) -> Policy:

     return policy

-  # clamp action with limits
   def clamp(self, action: np.array) -> np.array:
+    """Return input clamped between limits.
+
+    Args:
+      action (np.array): input vector
+
+    Returns:
+      np.array: clamped input vector
+    """
     # clamp within limits
     if self._limits is not None:
       return np.minimum(
@@ -184,7 +239,6 @@ def clamp(self, action: np.array) -> np.array:
     return action


-# rollout
 def rollout(
     qpos: np.array,
     qvel: np.array,
@@ -198,6 +252,24 @@ def rollout(
     policy: Policy,
     horizon: float,
 ) -> float:
+  """Return total return by rolling out plan with forward dynamics.
+
+  Args:
+    qpos (np.array): initial configuration
+    qvel (np.array): initial velocity
+    act (np.array): initial activation
+    time (float): current time
+    mocap_pos (np.array): motion-capture body positions
+    mocap_quat (np.array): motion-capture body orientations
+    model (mujoco.MjModel): MuJoCo model
+    data (mujoco.MjData): MuJoCo data
+    reward (function): function returning per-timestep reward value
+    policy (Policy): plan for computing action at given time
+    horizon (float): planning duration (seconds)
+
+  Returns:
+    float: total return (normalized by number of planning steps)
+  """
   # number of steps
   steps = int(horizon / model.opt.timestep)

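The `rollout` docstring above specifies the inputs, the per-timestep reward callback, and a return normalized by the number of planning steps (the `total_reward / (steps + 1)` line is visible in the next hunk). Below is a hedged sketch of such a loop with the MuJoCo Python bindings; the `reward(model, data)` signature and the exact ordering of reward evaluation and stepping are assumptions.

```python
import mujoco

def rollout_sketch(qpos, qvel, act, time, mocap_pos, mocap_quat,
                   model, data, reward, policy, horizon):
  """Roll out a plan with forward dynamics and return the average reward."""
  steps = int(horizon / model.opt.timestep)

  # set the initial state
  mujoco.mj_resetData(model, data)
  data.qpos[:] = qpos
  data.qvel[:] = qvel
  data.act[:] = act
  data.time = time
  data.mocap_pos[:] = mocap_pos
  data.mocap_quat[:] = mocap_quat

  total_reward = 0.0
  for _ in range(steps + 1):
    data.ctrl[:] = policy.action(data.time)  # evaluate the plan at the current time
    total_reward += reward(model, data)
    mujoco.mj_step(model, data)

  return total_reward / (steps + 1)
```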
@@ -233,9 +305,9 @@ def rollout(
   return total_reward / (steps + 1)


-# predictive sampling planner class
 class Planner:
-  # initialize planner
+  """Predictive sampling planner class."""
+
   def __init__(
       self,
       model: mujoco.MjModel,
@@ -249,6 +321,20 @@ def __init__(
       interp: str = "zero",
       limits: bool = True,
   ):
+    """Initialize planner.
+
+    Args:
+      model (mujoco.MjModel): MuJoCo model
+      reward (function): function returning per-timestep reward value
+      horizon (float): planning duration (seconds)
+      splinestep (float): interval length between spline points
+      planstep (float): interval length between forward dynamics steps
+      nsample (int): number of noisy plans to evaluate
+      noise_scale (float): standard deviation of zero-mean Gaussian noise
+      nimprove (int): number of iterations to improve plan for fixed initial state
+      interp (str, optional): type of action interpolation. Defaults to "zero".
+      limits (bool, optional): lower and upper bounds on action. Defaults to True.
+    """
     self._model = model.__copy__()
     self._model.opt.timestep = planstep
     self._data = mujoco.MjData(self._model)
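Given the constructor arguments documented above and the `improve_policy` / `action_from_policy` methods further down, a hypothetical receding-horizon control step could look like the following. The XML path, the reward function, and its `(model, data)` signature are placeholders; only the keyword names come from the docstring.

```python
import mujoco
import numpy as np

from predictive_sampling import Planner  # hypothetical import path

def reward(model, data):
  # illustrative reward: stay near qpos = 0 with small velocities
  return -np.sum(data.qpos**2) - 0.1 * np.sum(data.qvel**2)

model = mujoco.MjModel.from_xml_path("task.xml")  # placeholder model file
data = mujoco.MjData(model)

planner = Planner(
    model=model,
    reward=reward,
    horizon=0.5,
    splinestep=0.1,
    planstep=0.01,
    nsample=10,
    noise_scale=0.1,
    nimprove=4,
)

# one control step: replan from the current state, then query the plan
planner.improve_policy(
    data.qpos, data.qvel, data.act, data.time, data.mocap_pos, data.mocap_quat
)
data.ctrl[:] = planner.action_from_policy(data.time)
mujoco.mj_step(model, data)
```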
@@ -265,11 +351,17 @@ def __init__(
     self._noise_scale = noise_scale
     self._nimprove = nimprove

-  # action from policy
   def action_from_policy(self, time: float) -> np.array:
+    """Return action at time from policy.
+
+    Args:
+      time (float): time to evaluate plan for action
+
+    Returns:
+      np.array: action interpolation at time
+    """
     return self.policy.action(time)

-  # improve policy
   def improve_policy(
       self,
       qpos: np.array,
@@ -279,6 +371,16 @@ def improve_policy(
       mocap_pos: np.array,
       mocap_quat: np.array,
   ):
+    """Iteratively improve plan via searching noisy plans.
+
+    Args:
+      qpos (np.array): initial configuration
+      qvel (np.array): initial velocity
+      act (np.array): initial activation
+      time (float): current time
+      mocap_pos (np.array): motion-capture body position
+      mocap_quat (np.array): motion-capture body orientation
+    """
     # resample
     self.policy.resample(time)

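The `improve_policy` docstring describes one predictive-sampling update: resample the nominal plan at the current time, perturb it, evaluate each candidate with `rollout`, and keep the best. Below is a rough sketch of that loop; attribute names such as `_nsample`, `_reward`, and `_horizon` (and the use of `np.argmax` over returns) are assumptions beyond what this diff shows.

```python
import numpy as np

from predictive_sampling import rollout  # hypothetical import path

def improve_policy_sketch(planner, qpos, qvel, act, time, mocap_pos, mocap_quat):
  """One predictive-sampling update of planner.policy, following the docstrings above."""
  for _ in range(planner._nimprove):
    # shift the nominal plan so its spline starts at the current time
    planner.policy.resample(time)

    # nominal plan plus nsample zero-mean Gaussian perturbations of it
    candidates = [planner.policy] + [
        planner.policy.noisy_copy(planner._noise_scale)
        for _ in range(planner._nsample)
    ]

    # score every candidate with a forward-dynamics rollout
    returns = [
        rollout(qpos, qvel, act, time, mocap_pos, mocap_quat,
                planner._model, planner._data, planner._reward,
                candidate, planner._horizon)
        for candidate in candidates
    ]

    # keep the highest-return plan as the new nominal
    planner.policy = candidates[int(np.argmax(returns))]
```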