@@ -567,7 +567,9 @@ def __init__(
567567 self .policy_weights = TensorDict ({}, [])
568568
569569 self .env : EnvBase = self .env .to (self .device )
570- self .max_frames_per_traj = max_frames_per_traj
570+ self .max_frames_per_traj = (
571+ int (max_frames_per_traj ) if max_frames_per_traj is not None else 0
572+ )
571573 if self .max_frames_per_traj is not None and self .max_frames_per_traj > 0 :
572574 # let's check that there is no StepCounter yet
573575 for key in self .env .output_spec .keys (True , True ):
@@ -595,9 +597,13 @@ def __init__(
595597 f"This means { frames_per_batch - remainder } additional frames will be collected."
596598 "To silence this message, set the environment variable RL_WARNINGS to False."
597599 )
598- self .total_frames = total_frames
600+ self .total_frames = (
601+ int (total_frames ) if total_frames != float ("inf" ) else total_frames
602+ )
599603 self .reset_at_each_iter = reset_at_each_iter
600- self .init_random_frames = init_random_frames
604+ self .init_random_frames = (
605+ int (init_random_frames ) if init_random_frames is not None else 0
606+ )
601607 if (
602608 init_random_frames is not None
603609 and init_random_frames % frames_per_batch != 0
@@ -620,7 +626,7 @@ def __init__(
620626 f" ({ - (- frames_per_batch // self .n_env ) * self .n_env } )."
621627 "To silence this message, set the environment variable RL_WARNINGS to False."
622628 )
623- self .requested_frames_per_batch = frames_per_batch
629+ self .requested_frames_per_batch = int ( frames_per_batch )
624630 self .frames_per_batch = - (- frames_per_batch // self .n_env )
625631 self .exploration_type = (
626632 exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
@@ -1234,11 +1240,15 @@ def device_err_msg(device_name, devices_list):
12341240 f"This means { frames_per_batch - remainder } additional frames will be collected."
12351241 "To silence this message, set the environment variable RL_WARNINGS to False."
12361242 )
1237- self .total_frames = total_frames
1243+ self .total_frames = (
1244+ int (total_frames ) if total_frames != float ("inf" ) else total_frames
1245+ )
12381246 self .reset_at_each_iter = reset_at_each_iter
12391247 self .postprocs = postproc
1240- self .max_frames_per_traj = max_frames_per_traj
1241- self .requested_frames_per_batch = frames_per_batch
1248+ self .max_frames_per_traj = (
1249+ int (max_frames_per_traj ) if max_frames_per_traj is not None else 0
1250+ )
1251+ self .requested_frames_per_batch = int (frames_per_batch )
12421252 self .reset_when_done = reset_when_done
12431253 if split_trajs is None :
12441254 split_trajs = False
@@ -1247,7 +1257,9 @@ def device_err_msg(device_name, devices_list):
12471257 "Cannot split trajectories when reset_when_done is False."
12481258 )
12491259 self .split_trajs = split_trajs
1250- self .init_random_frames = init_random_frames
1260+ self .init_random_frames = (
1261+ int (init_random_frames ) if init_random_frames is not None else 0
1262+ )
12511263 self .update_at_each_batch = update_at_each_batch
12521264 self .exploration_type = exploration_type
12531265 self .frames_per_worker = np .inf
0 commit comments