 # https://arxiv.org/abs/2212.00541


-# policy class for predictive sampling
 class Policy:
-  # initialize policy
+  """Policy class for predictive sampling."""
+
   def __init__(
       self,
       naction: int,
@@ -33,6 +33,15 @@ def __init__(
       interp: str = "zero",
       limits: np.array = None,
   ):
36+ """Initialize policy class.
37+
38+ Args:
39+ naction (int): number of actions
40+ horizon (float): planning horizon (seconds)
41+ splinestep (float): interval length between spline points
42+ interp (str, optional): type of action interpolation. Defaults to "zero".
43+ limits (np.array, optional): lower and upper bounds on actions. Defaults to None.
44+ """
     self._naction = naction
     self._splinestep = splinestep
     self._horizon = horizon
@@ -44,8 +53,16 @@ def __init__(
     self._interp = interp
     self._limits = limits

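For orientation, a minimal sketch of constructing this class; the argument values are illustrative, not taken from the diff (the positional order `naction, horizon, splinestep` matches the call inside `noisy_copy` below):

```python
import numpy as np

# Illustrative instantiation (values are assumptions, not from the diff):
# a 2-action plan over a 0.5 s horizon with spline knots every 0.1 s.
policy = Policy(naction=2, horizon=0.5, splinestep=0.1, interp="zero")
```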
-  # find interval containing value
   def _find_interval(self, sequence: np.array, value: float) -> [int, int]:
+    """Find neighboring indices in sequence containing value.
+
+    Args:
+      sequence (np.array): array of values
+      value (float): value to find in interval
+
+    Returns:
+      [int, int]: lower and upper indices in sequence containing value
+    """
     # bisection search to get interval
     upper = bisect.bisect_right(sequence, value)
     lower = upper - 1
@@ -60,8 +77,17 @@ def _find_interval(self, sequence: np.array, value: float) -> [int, int]:
       return (L - 1, L - 1)
     return (max(lower, 0), min(upper, L - 1))

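The interval lookup above relies on `bisect.bisect_right`; a standalone sketch of the same bracketing logic, with illustrative knot times:

```python
import bisect

# Knot times (illustrative). bisect_right returns the insertion point to
# the right of existing entries, so value = 0.15 brackets as (1, 2).
times = [0.0, 0.1, 0.2, 0.3]
upper = bisect.bisect_right(times, 0.15)
lower = upper - 1
print((lower, upper))  # (1, 2)
```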
-  # compute slope at value
   def _slope(self, input: np.array, output: np.array, value: float) -> np.array:
+    """Compute interpolated slope vector at value.
+
+    Args:
+      input (np.array): scalar sequence of inputs
+      output (np.array): vector sequence of outputs
+      value (float): input where to compute slope
+
+    Returns:
+      np.array: interpolated slope vector
+    """
     # bounds
     bounds = self._find_interval(input, value)

@@ -97,8 +123,15 @@ def _slope(self, input: np.array, output: np.array, value: float) -> np.array:
         input[bounds[0]] - input[bounds[0] - 1]
     )

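Only the tail of `_slope` is visible in this hunk; for intuition, the divided difference it computes over a bracketing interval looks like the following standalone sketch (the row-per-knot layout here is for readability; the class stores knots column-wise):

```python
import numpy as np

# Knot inputs (times) and vector outputs (one action row per knot), illustrative.
inputs = np.array([0.0, 0.1, 0.2])
outputs = np.array([[0.0, 0.5], [1.0, 0.7], [0.5, 0.9]])

# Divided difference over the interval (1, 2) bracketing t = 0.15:
# one slope entry per action dimension.
lower, upper = 1, 2
slope = (outputs[upper] - outputs[lower]) / (inputs[upper] - inputs[lower])
print(slope)  # [-5.  2.]
```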
-  # get action from policy
   def action(self, time: float) -> np.array:
+    """Return action from policy at time.
+
+    Args:
+      time (float): time value to evaluate plan for action
+
+    Returns:
+      np.array: interpolated action at time
+    """
     # find interval containing time
     bounds = self._find_interval(self._times, time)

@@ -141,8 +174,12 @@ def action(self, time: float) -> np.array:
     else:  # self._interp == "zero"
       return self.clamp(self._parameters[:, bounds[0]])

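The `"zero"` branch above is a zero-order hold: the plan returns the parameter column at the lower knot of the bracketing interval, held constant until the next knot. A standalone sketch with illustrative values:

```python
import numpy as np

# Knot times and action parameters (naction x nspline), illustrative.
times = [0.0, 0.1, 0.2]
parameters = np.array([[0.0, 1.0, 0.5],
                       [0.5, 0.7, 0.9]])

# Zero-order hold at t = 0.15: the interval is (1, 2), so take column 1.
action = parameters[:, 1]
print(action)  # [1.  0.7]
```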
-  # resample policy plan from current time
   def resample(self, time: float):
+    """Resample plan starting from time.
+
+    Args:
+      time (float): time value to start updated plan
+    """
     # new times and parameters
     times = np.array(
         [i * self._splinestep + time for i in range(self._nspline)], dtype=float
@@ -153,16 +190,27 @@ def resample(self, time: float):
     self._times = times
     self._parameters = parameters

-  # add zero-mean Gaussian noise to policy parameters
   def add_noise(self, scale: float):
+    """Add zero-mean Gaussian noise to plan.
+
+    Args:
+      scale (float): standard deviation of zero-mean Gaussian noise
+    """
     # clamp within limits
     self._parameters = self.clamp(
         self._parameters
         + np.random.normal(scale=scale, size=(self._naction, self._nspline))
     )

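The whole knot matrix is perturbed in one draw; the equivalent standalone operation, with `np.clip` standing in for the class's `clamp` and illustrative shapes and bounds:

```python
import numpy as np

naction, nspline, scale = 2, 5, 0.1  # illustrative sizes
parameters = np.zeros((naction, nspline))

# One zero-mean Gaussian sample per spline parameter, then clamp to
# illustrative bounds; np.clip stands in for Policy.clamp here.
noisy = np.clip(
    parameters + np.random.normal(scale=scale, size=(naction, nspline)),
    -1.0,
    1.0,
)
```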
-  # return a copy of the policy with noisy parameters
   def noisy_copy(self, scale: float) -> Policy:
+    """Return a copy of plan with added noise.
+
+    Args:
+      scale (float): standard deviation of zero-mean Gaussian noise
+
+    Returns:
+      Policy: copy of object with noisy plan
+    """
     # create new policy object
     policy = Policy(self._naction, self._horizon, self._splinestep)
@@ -174,8 +222,15 @@ def noisy_copy(self, scale: float) -> Policy:

     return policy

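`noisy_copy` is what generates candidate plans in predictive sampling; a sketch of how a caller might use it (the `nsample` and `scale` values are illustrative):

```python
# Sample candidate plans around the nominal plan (illustrative values).
nsample, scale = 10, 0.1
candidates = [policy.noisy_copy(scale) for _ in range(nsample)]
```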
-  # clamp action with limits
   def clamp(self, action: np.array) -> np.array:
+    """Return input clamped between limits.
+
+    Args:
+      action (np.array): input vector
+
+    Returns:
+      np.array: clamped input vector
+    """
     # clamp within limits
     if self._limits is not None:
       return np.minimum(
@@ -184,7 +239,6 @@ def clamp(self, action: np.array) -> np.array:
     return action


-# rollout
 def rollout(
     qpos: np.array,
     qvel: np.array,
@@ -198,6 +252,24 @@ def rollout(
     policy: Policy,
     horizon: float,
 ) -> float:
255+ """Return total return by rollout out plan with forward dynamics.
256+
257+ Args:
258+ qpos (np.array): initial configuration
259+ qvel (np.array): initial velocity
260+ act (np.array): initial activation
261+ time (float): current time
262+ mocap_pos (np.array): motion-capture body positions
263+ mocap_quat (np.array): motion-capture body orientations
264+ model (mujoco.MjModel): MuJoCo model
265+ data (mujoco.MjData): MuJoCo data
266+ reward (function): function returning per-timestep reward value
267+ policy (Policy): plan for computing action at given time
268+ horizon (float): planning duration (seconds)
269+
270+ Returns:
271+ float: total return (normalized by number of planning steps)
272+ """
   # number of steps
   steps = int(horizon / model.opt.timestep)

@@ -233,9 +305,9 @@ def rollout(
   return total_reward / (steps + 1)

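Dividing by `steps + 1` makes the score an average per-step reward, so plans stay comparable across horizons. A sketch of scoring one candidate with this function; the state variables, `reward_fn`, and `candidate` are placeholders for whatever the caller tracks:

```python
# Score a candidate plan from the current state (placeholder arguments).
score = rollout(
    qpos, qvel, act, time,
    mocap_pos, mocap_quat,
    model, data,
    reward_fn, candidate, horizon,
)
```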

-# predictive sampling planner class
 class Planner:
-  # initialize planner
+  """Predictive sampling planner class."""
+
   def __init__(
       self,
       model: mujoco.MjModel,
@@ -249,6 +321,20 @@ def __init__(
       interp: str = "zero",
       limits: bool = True,
   ):
324+ """Initialize planner.
325+
326+ Args:
327+ model (mujoco.MjModel): MuJoCo model
328+ reward (function): function returning per-timestep reward value
329+ horizon (float): planning duration (seconds)
330+ splinestep (float): interval length between spline points
331+ planstep (float): interval length between forward dynamics steps
332+ nsample (int): number of noisy plans to evaluate
333+ noise_scale (float): standard deviation of zero-mean Gaussian
334+ nimprove (int): number of iterations to improve plan for fixed initial state
335+ interp (str, optional): type of action interpolation. Defaults to "zero".
336+ limits (bool, optional): lower and upper bounds on action. Defaults to True.
337+ """
     self._model = model.__copy__()
     self._model.opt.timestep = planstep
     self._data = mujoco.MjData(self._model)
@@ -265,11 +351,17 @@ def __init__(
     self._noise_scale = noise_scale
     self._nimprove = nimprove

-  # action from policy
   def action_from_policy(self, time: float) -> np.array:
+    """Return action at time from policy.
+
+    Args:
+      time (float): time to evaluate plan for action
+
+    Returns:
+      np.array: action interpolation at time
+    """
     return self.policy.action(time)

-  # improve policy
   def improve_policy(
       self,
       qpos: np.array,
@@ -279,6 +371,16 @@ def improve_policy(
       mocap_pos: np.array,
       mocap_quat: np.array,
   ):
374+ """Iteratively improve plan via searching noisy plans.
375+
376+ Args:
377+ qpos (np.array): initial configuration
378+ qvel (np.array): initial velocity
379+ act (np.array): initial activation
380+ time (float): current time
381+ mocap_pos (np.array): motion-capture body position
382+ mocap_quat (np.array): motion-capture body orientation
383+ """
     # resample
     self.policy.resample(time)

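The excerpt cuts off here, but the docstring and the constructor's `nsample`, `noise_scale`, and `nimprove` arguments outline the rest of the loop. Below is a sketch of one improvement iteration consistent with those pieces; attribute names like `self._nsample`, `self._reward`, and `self._horizon` are assumptions, and this is not the diff's verbatim body:

```python
# One predictive-sampling improvement step (a sketch, not the verbatim
# method body): evaluate the nominal plan and nsample noisy copies from
# the same initial state, then keep the best-scoring plan.
candidates = [self.policy] + [
    self.policy.noisy_copy(self._noise_scale) for _ in range(self._nsample)
]
returns = [
    rollout(
        qpos, qvel, act, time,
        mocap_pos, mocap_quat,
        self._model, self._data,
        self._reward, candidate, self._horizon,  # attribute names assumed
    )
    for candidate in candidates
]
self.policy = candidates[int(np.argmax(returns))]
```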