@@ -277,12 +277,15 @@ def update(self, cache_update: StateSamplerCacheUpdate) -> None:
277277 self ._cache [new_cache_key ] = new_cache_entry
278278
279279
280- def clear_illegal_actions ( trajectory : tuple [int ], rewards : tuple [float ] ) -> tuple [ tuple [int ], tuple [float ]]:
280+ def clear_illegal_actions (
281+ trajectory : tuple [int ], rewards : tuple [float ]
282+ ) -> tuple [tuple [int ], tuple [float ]]:
281283 filtered = [(t , r ) for t , r in zip (trajectory , rewards ) if r > - 0.101 ]
282284 # Unzip the filtered values into separate tuples
283285 new_trajectory , new_rewards = zip (* filtered ) if filtered else ((), ())
284286 return new_trajectory , new_rewards
285-
287+
288+
286289def rollout (
287290 rollout_params : RolloutParams ,
288291 start_current_best_trajectory_length : int ,
@@ -321,14 +324,18 @@ def rollout(
321324 # the validity of the rest of the trajectory
322325 # because illegal actions cost reward but do not
323326 # change state.
324- current_trajectory , rewards = clear_illegal_actions (current_trajectory , rewards )
325- success_entry = SuccessEntry (trajectory = current_trajectory , rewards = rewards )
327+ current_trajectory , rewards = clear_illegal_actions (
328+ current_trajectory , rewards
329+ )
330+ success_entry = SuccessEntry (
331+ trajectory = current_trajectory , rewards = rewards
332+ )
326333 if success_entry not in success_entries :
327334 success_entries .add (
328335 SuccessEntry (trajectory = current_trajectory , rewards = rewards )
329336 )
330337 led_to_something_new = True
331-
338+
332339 state_sampler_cache_update .update_current_best_trajectory (
333340 len (current_trajectory )
334341 )
0 commit comments