
Commit 758a3eb

adding parallelization support
1 parent 8e82823 commit 758a3eb

1 file changed (+36 −0)

mitdeeplearning/lab3.py

Lines changed: 36 additions & 0 deletions
@@ -41,7 +41,43 @@ def pong_change(prev, curr):
    I = (I - I.min()) / (I.max() - I.min() + 1e-10)
    return I

+def parallelized_collect_rollout(batch_size, envs, model, choose_action):

+    assert len(envs) == batch_size, "Number of parallel environments must be equal to the batch size."
+
+    memories = [Memory() for _ in range(batch_size)]
+    next_observations = [single_env.reset() for single_env in envs]
+    previous_frames = [obs for obs in next_observations]
+    done = [False] * batch_size
+    rewards = [0] * batch_size
+
+    tic = time.time()
+    while True:
+
+        current_frames = [obs for obs in next_observations]
+        diff_frames = [pong_change(prev, curr) for (prev, curr) in zip(previous_frames, current_frames)]
+
+        diff_frames_not_done = [diff_frames[b] for b in range(batch_size) if not done[b]]
+        actions_not_done = choose_action(model, np.array(diff_frames_not_done), single=False)
+
+        actions = [None] * batch_size
+        ind_not_done = 0
+        for b in range(batch_size):
+            if not done[b]:
+                actions[b] = actions_not_done[ind_not_done]
+                ind_not_done += 1
+
+        for b in range(batch_size):
+            if done[b]:
+                continue
+            next_observations[b], rewards[b], done[b], info = envs[b].step(actions[b])
+            previous_frames[b] = current_frames[b]
+            memories[b].add_to_memory(diff_frames[b], actions[b], rewards[b])
+
+        if all(done):
+            break
+
+    return memories


def save_video_of_model(model, env_name, suffix=""):
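
For context, a minimal usage sketch of the new helper (not part of the commit). It assumes gym is installed, that model, choose_action, and the lab's Memory class are in scope as defined elsewhere in lab3.py and its accompanying notebook, and that the environment id and batch size shown here are purely illustrative.

import gym

# Hypothetical setup: one independent Pong environment per rollout in the batch.
# The env id and batch size are illustrative, not taken from the commit.
batch_size = 8
envs = [gym.make("Pong-v0") for _ in range(batch_size)]

# model and choose_action(model, observations, single=False) are assumed to be
# defined as in the rest of lab3; choose_action must return one action per
# not-yet-finished environment when given a batch of frame differences.
memories = parallelized_collect_rollout(batch_size, envs, model, choose_action)

# Each returned Memory holds one rollout's (frame-difference, action, reward)
# history; assuming Memory exposes a rewards list, the per-episode returns are:
total_rewards = [sum(memory.rewards) for memory in memories]

Stepping all environments in lockstep this way lets choose_action run one batched forward pass per timestep over the environments that are still active, rather than one model call per environment, which is where the parallelization speedup comes from.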
