
Commit 33cd8e3

Alexander Amini authored and committed
updates to parallel rollouts
1 parent e373962 commit 33cd8e3

File tree

2 files changed: +90 -68 lines
  lab3/solutions/pong.py
  mitdeeplearning/lab3.py

lab3/solutions/pong.py

Lines changed: 73 additions & 68 deletions
@@ -29,7 +29,7 @@
 tf.config.experimental.set_memory_growth(physical_devices[0], True)


-env = gym.make("Pong-v0", frameskip=5, difficulty=0)
+env = gym.make("Pong-v0", frameskip=5)
 env.seed(1) # for reproducibility

 n_actions = env.action_space.n
@@ -44,26 +44,25 @@
 # observation: observation which is fed as input to the model
 # Returns:
 #   action: choice of agent action
-def choose_action(model, observations):
+def choose_action(model, observation, single=True):
+  if single: # create a batch dimension if only a single example was provided
+    observations = np.expand_dims(observation, axis=0)
+
   # add batch dimension to the observation
   # observation = np.expand_dims(observation, axis=0)
   '''TODO: feed the observations through the model to predict the log probabilities of each possible action.'''
-
   logits = model.predict(observations) # TODO
   # logits = model.predict('''TODO''')

   # pass the log probabilities through a softmax to compute true probabilities
-  prob_weights = tf.nn.softmax(logits)
+  # prob_weights = tf.nn.softmax(logits)
   '''TODO: randomly sample from the prob_weights to pick an action.
   Hint: carefully consider the dimensionality of the input probabilities (vector) and the output action (scalar)'''

-  action = tf.random.categorical(logits, 1)[:,0].numpy()
+  action = tf.random.categorical(logits, num_samples=1)
+  action = action.numpy().flatten()

-  # action = np.random.choice(
-  #     n_actions, size=1, p=prob_weights.flatten())[0] # TODO
-  # action = np.random.choice('''TODO''', size=1, p=''''TODO''')['''TODO''']
-
-  return action
+  return action[0] if single else action


### Reward function ###
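For reference, a minimal self-contained sketch of how the updated choose_action interface behaves for a single observation versus a batch. Here dummy_model and the 40x40x1 input shape are stand-ins, and an explicit else branch is included so the single=False path has observations defined:

import numpy as np
import tensorflow as tf

n_actions = 3
dummy_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(40, 40, 1)),
    tf.keras.layers.Dense(n_actions)
])

def choose_action(model, observation, single=True):
  if single:  # create a batch dimension if only a single example was provided
    observations = np.expand_dims(observation, axis=0)
  else:       # assumption: the caller already passes a batched array
    observations = observation
  logits = model.predict(observations)                   # unnormalized log-probabilities, shape (batch, n_actions)
  action = tf.random.categorical(logits, num_samples=1)  # sample one action per observation
  action = action.numpy().flatten()
  return action[0] if single else action

print(choose_action(dummy_model, np.zeros((40, 40, 1), np.float32)))                  # scalar action
print(choose_action(dummy_model, np.zeros((8, 40, 40, 1), np.float32), single=False)) # array of 8 actions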
@@ -99,6 +98,15 @@ def add_to_memory(self, new_observation, new_action, new_reward):
       # ['''TODO''']


+# Combine a list of Memory objects into a single Memory (e.g., for batching)
+def aggregate_memories(memories):
+  batch_memory = Memory()
+  for memory in memories:
+    for step in zip(memory.observations, memory.actions, memory.rewards):
+      batch_memory.add_to_memory(*step)
+  return batch_memory
+
+
 memory = Memory()

### Loss function ###
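A quick sketch of what aggregate_memories does, using a stand-in Memory class consistent with the add_to_memory signature shown above (the lab's actual Memory class is defined earlier in pong.py):

class Memory:  # stand-in with the same interface as the lab's Memory
  def __init__(self):
    self.observations, self.actions, self.rewards = [], [], []
  def add_to_memory(self, new_observation, new_action, new_reward):
    self.observations.append(new_observation)
    self.actions.append(new_action)
    self.rewards.append(new_reward)

def aggregate_memories(memories):
  batch_memory = Memory()
  for memory in memories:
    for step in zip(memory.observations, memory.actions, memory.rewards):
      batch_memory.add_to_memory(*step)  # unpack (observation, action, reward)
  return batch_memory

m1, m2 = Memory(), Memory()
m1.add_to_memory("obs_a", 0, 1.0)
m2.add_to_memory("obs_b", 2, -1.0)
batch = aggregate_memories([m1, m2])
print(len(batch.observations))  # 2 -- steps from both rollouts now in one Memory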
@@ -200,14 +208,6 @@ def discount_rewards(rewards, gamma=0.99):
   return normalize(discounted_rewards)


-def fix(img):
-  return cv2.resize(
-    cv2.dilate(img, np.ones((2, 2), np.uint8), iterations=1),
-    None,
-    fx=0.5,
-    fy=0.5)[:, :, np.newaxis]
-
-
 # env.reset()
 # for i in range(1000):
 #   observation, _,_,_ = env.step(0)
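The hunk above sits inside discount_rewards(rewards, gamma=0.99), which ends in return normalize(discounted_rewards). As a generic illustration only (not the lab's exact implementation, which may also reset the running sum at game boundaries), discounted and normalized returns can be computed like this:

import numpy as np

def discount_rewards(rewards, gamma=0.99):
  discounted = np.zeros(len(rewards), dtype=np.float32)
  R = 0.0
  for t in reversed(range(len(rewards))):
    R = rewards[t] + gamma * R   # running discounted sum, accumulated backwards in time
    discounted[t] = R
  # normalize to zero mean / unit variance for more stable policy-gradient updates
  return (discounted - discounted.mean()) / (discounted.std() + 1e-8)

print(discount_rewards([0.0, 0.0, 1.0]))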
@@ -221,11 +221,10 @@ def fix(img):

 # Hyperparameters
 learning_rate = args.learning_rate
-MAX_ITERS = 10000 # increase the maximum number of episodes, since Pong is more complex!
+MAX_ITERS = 100 # increase the maximum number of episodes, since Pong is more complex!

 # Model and optimizer
 pong_model = create_pong_model()
-pong_model.build((None, 40, 40, 1))

 optimizer = tf.keras.optimizers.Adam(learning_rate)

@@ -238,49 +237,15 @@ def fix(img):



+def parallelized_collect_rollout(batch_size, envs, model, choose_action):

+  assert len(envs) == batch_size, "Number of parallel environments must be equal to the batch size."

-
-# def run_episode(env, model):
-#   print("running episode")
-#   memory = Memory()
-#   observation = env.reset()
-#   previous_frame = fix(mdl.lab3.preprocess_pong(observation))
-#   done = False
-#   while not done:
-#     # Pre-process image
-#     current_frame = fix(mdl.lab3.preprocess_pong(observation))
-#     obs_change = current_frame - previous_frame # TODO
-#
-#     # obs_change = # TODO
-#     tic = time.time()
-#     action = choose_action(model, obs_change) # TODO
-#
-#     # action = # TODO
-#     # Take the chosen action
-#     tic = time.time()
-#     next_observation, reward, done, info = env.step(action)
-#
-#     memory.add_to_memory(obs_change, action, reward) # TODO
-#
-#     observation = next_observation
-#     previous_frame = current_frame
-#   return memory
-
-
-envs = [copy.deepcopy(env) for _ in range(batch_size)]
-
-for i_episode in range(MAX_ITERS):
-
-  tic = time.time()
   memories = [Memory() for _ in range(batch_size)]
   next_observations = [single_env.reset() for single_env in envs]
   previous_frames = [obs for obs in next_observations]
   done = [False] * batch_size
-  actions = [0] * batch_size
   rewards = [0] * batch_size
-  print("reiniting", time.time()-tic)
-

   tic = time.time()
   while True:
@@ -289,7 +254,7 @@ def fix(img):
     diff_frames = [mdl.lab3.pong_change(prev, curr) for (prev, curr) in zip(previous_frames, current_frames)]

     diff_frames_not_done = [diff_frames[b] for b in range(batch_size) if not done[b]]
-    actions_not_done = choose_action(pong_model, np.array(diff_frames_not_done))
+    actions_not_done = choose_action(model, np.array(diff_frames_not_done), single=False)

     actions = [None] * batch_size
     ind_not_done = 0
@@ -305,10 +270,56 @@ def fix(img):
       previous_frames[b] = current_frames[b]
       memories[b].add_to_memory(diff_frames[b], actions[b], rewards[b])

-
     if all(done):
       break

+  return memories
+
+
+
+def collect_rollout(batch_size, env, model, choose_action):
+
+  memories = []
+
+  for b in range(batch_size):
+    memory = Memory()
+    next_observation = env.reset()
+    previous_frame = next_observation
+    done = False
+
+    while not done:
+      current_frame = next_observation
+      diff_frame = mdl.lab3.pong_change(previous_frame, current_frame)
+
+      action = choose_action(model, diff_frame)
+
+      next_observation, reward, done, info = env.step(action)
+
+      previous_frame = current_frame
+      memory.add_to_memory(diff_frame, action, reward)
+
+    memories.append(memory)
+
+  return memories
+
+
+mdl.lab3.save_video_of_memory(memory[0])
+collect_rollout(batch_size, env, model, choose_action)
+
+
+
+
+envs = [copy.deepcopy(env) for _ in range(batch_size)]
+
+for i_episode in range(MAX_ITERS):
+
+
+  tic = time.time()
+  memories = collect_rollout(batch_size, env, pong_model, choose_action)
+  # memories = parallelized_collect_rollout(batch_size, envs, pong_model, choose_action)
+  batch_memory = aggregate_memories(memories)
+  print(time.time()-tic)
+

   # def parallel_episode(i):
   #   return run_episode(env=copy.deepcopy(env), model=pong_model)
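The core idea of parallelized_collect_rollout, stepping only the environments that are not yet done and scattering the results back into their batch slots, can be illustrated with a toy self-contained sketch (ToyEnv and the random policy are stand-ins, not the lab's Pong setup):

import numpy as np

class ToyEnv:  # stand-in environment with a gym-like reset/step interface
  def __init__(self):
    self.t = 0
  def reset(self):
    self.t = 0
    return np.zeros(4)
  def step(self, action):
    self.t += 1
    return np.random.rand(4), float(action), self.t >= 5, {}

batch_size = 3
envs = [ToyEnv() for _ in range(batch_size)]
observations = [single_env.reset() for single_env in envs]
done = [False] * batch_size

while not all(done):
  # choose actions only for environments that are still running
  idx_not_done = [b for b in range(batch_size) if not done[b]]
  actions_not_done = np.random.randint(0, 2, size=len(idx_not_done))  # stand-in policy

  # scatter the chosen actions back to the corresponding batch slots and step
  for action, b in zip(actions_not_done, idx_not_done):
    observations[b], reward, done[b], _ = envs[b].step(action)

print("all environments finished")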
@@ -323,19 +334,13 @@ def fix(img):
   #   memories = pool.map(parallel_episode, models)#range(batch_size))
   #   print(time.time()-tic)

-  print(time.time()-tic)
-
-  batch_memory = Memory()
-  for memory in memories:
-    for step in zip(memory.observations, memory.actions, memory.rewards):
-      batch_memory.add_to_memory(*step)

+  # batch_memory = Memory()
+  # for memory in memories:
+  #   for step in zip(memory.observations, memory.actions, memory.rewards):
+  #     batch_memory.add_to_memory(*step)


-def play(memory):
-  for o in memory.observations:
-    cv2.imshow('hi', cv2.resize(o, (500,500)))
-    cv2.waitKey(20)



mitdeeplearning/lab3.py

Lines changed: 17 additions & 0 deletions
@@ -81,3 +81,20 @@ def save_video_of_model(model, env_name, obs_diff=False, pp_fn=None):
     output_video.close()
     print("Successfully saved {} frames into {}!".format(counter, filename))
     return filename
+
+
+def save_video_of_memory(memory):
+    import skvideo.io
+    from pyvirtualdisplay import Display
+    display = Display(visible=0, size=(400, 300))
+    display.start()
+
+    filename = env_name + ".mp4"
+    output_video = skvideo.io.FFmpegWriter(filename)
+
+    for observation in memory.observations:
+        output_video.writeFrame(observation)
+
+    output_video.close()
+    print("Successfully saved {} frames into {}!".format(counter, filename))
+    return filename
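As committed, save_video_of_memory refers to env_name and counter, which are not parameters of the function. A self-contained variant might look like the sketch below, where the filename and display-size parameters are assumptions rather than part of the committed API:

def save_video_of_memory(memory, filename="pong_rollout.mp4", size=(400, 300)):
    # assumes skvideo and pyvirtualdisplay are available, as in save_video_of_model above
    import skvideo.io
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=size)
    display.start()

    output_video = skvideo.io.FFmpegWriter(filename)
    for observation in memory.observations:
        output_video.writeFrame(observation)
    output_video.close()

    print("Successfully saved {} frames into {}!".format(len(memory.observations), filename))
    return filename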
