@@ -330,9 +330,9 @@ def rollout_consistency_assertion(
330330):
331331 """Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise."""
332332
333- done = rollout [: , :- 1 ]["next" , done_key ].squeeze (- 1 )
333+ done = rollout [... , :- 1 ]["next" , done_key ].squeeze (- 1 )
334334 # data resulting from step, when it's not done
335- r_not_done = rollout [: , :- 1 ]["next" ][~ done ]
335+ r_not_done = rollout [... , :- 1 ]["next" ][~ done ]
336336 # data resulting from step, when it's not done, after step_mdp
337337 r_not_done_tp1 = rollout [:, 1 :][~ done ]
338338 torch .testing .assert_close (
@@ -343,17 +343,15 @@ def rollout_consistency_assertion(
343343
344344 if done_strict and not done .any ():
345345 raise RuntimeError ("No done detected, test could not complete." )
346-
347- # data resulting from step, when it's done
348- r_done = rollout [:, :- 1 ]["next" ][done ]
349- # data resulting from step, when it's done, after step_mdp and reset
350- r_done_tp1 = rollout [:, 1 :][done ]
351- assert (
352- (r_done [observation_key ] - r_done_tp1 [observation_key ]).norm (dim = - 1 ) > 1e-1
353- ).all (), (
354- f"Entries in next tensordict do not match entries in root "
355- f"tensordict after reset : { (r_done [observation_key ] - r_done_tp1 [observation_key ]).norm (dim = - 1 ) < 1e-1 } "
356- )
346+ if done .any ():
347+ # data resulting from step, when it's done
348+ r_done = rollout [..., :- 1 ]["next" ][done ]
349+ # data resulting from step, when it's done, after step_mdp and reset
350+ r_done_tp1 = rollout [..., 1 :][done ]
351+ # check that at least one obs after reset does not match the version before reset
352+ assert not torch .isclose (
353+ r_done [observation_key ], r_done_tp1 [observation_key ]
354+ ).all ()
357355
358356
359357def rand_reset (env ):
0 commit comments