@@ -41,7 +41,43 @@ def pong_change(prev, curr):
    I = (I - I.min()) / (I.max() - I.min() + 1e-10)
    return I

+def parallelized_collect_rollout(batch_size, envs, model, choose_action):
+
+    assert len(envs) == batch_size, "Number of parallel environments must be equal to the batch size."
+
+    # One Memory buffer per environment; reset each environment to get its first observation.
+    memories = [Memory() for _ in range(batch_size)]
+    next_observations = [single_env.reset() for single_env in envs]
+    previous_frames = [obs for obs in next_observations]
+    done = [False] * batch_size
+    rewards = [0] * batch_size
+
+    tic = time.time()
+    while True:
+
+        # Compute the difference frame (current vs. previous observation) for every environment.
+        current_frames = [obs for obs in next_observations]
+        diff_frames = [pong_change(prev, curr) for (prev, curr) in zip(previous_frames, current_frames)]
+
+        # Batch the frames of environments that are still running and run a single forward pass.
+        diff_frames_not_done = [diff_frames[b] for b in range(batch_size) if not done[b]]
+        actions_not_done = choose_action(model, np.array(diff_frames_not_done), single=False)
+
+        # Scatter the predicted actions back to their environment indices.
+        actions = [None] * batch_size
+        ind_not_done = 0
+        for b in range(batch_size):
+            if not done[b]:
+                actions[b] = actions_not_done[ind_not_done]
+                ind_not_done += 1
+
+        # Step each unfinished environment and record the transition in its memory.
+        for b in range(batch_size):
+            if done[b]:
+                continue
+            next_observations[b], rewards[b], done[b], info = envs[b].step(actions[b])
+            previous_frames[b] = current_frames[b]
+            memories[b].add_to_memory(diff_frames[b], actions[b], rewards[b])
+
+        # Stop once every environment has finished its episode.
+        if all(done):
+            break
+
+    return memories


def save_video_of_model(model, env_name, suffix=""):
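A minimal usage sketch (not part of this commit) of how the new rollout helper might be driven, assuming the environments come from gym and that the surrounding lab code provides choose_action; create_pong_model() is a hypothetical constructor name used only for illustration:

    import gym

    batch_size = 8  # assumed value; must equal len(envs)
    envs = [gym.make("Pong-v0") for _ in range(batch_size)]
    model = create_pong_model()  # hypothetical: any policy model accepted by choose_action
    memories = parallelized_collect_rollout(batch_size, envs, model, choose_action)
    # memories[b] now holds the (diff_frame, action, reward) trajectory collected from envs[b]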