Commit e373962

Alexander Amini authored and committed
parallelizing pong training
1 parent 0428dfe commit e373962

File tree

3 files changed: +253 −112 lines changed


lab3/solutions/pong.py

Lines changed: 82 additions & 112 deletions
@@ -3,6 +3,9 @@
import argparse
import multiprocessing
from multiprocessing import Pool
+import os
+os.environ['KMP_DUPLICATE_LIB_OK']='True'
+import copy


import tensorflow as tf
@@ -21,48 +24,11 @@
print(args)


-from tensorflow.keras.models import Sequential, Model
-from tensorflow.keras.layers import Dense
-from tensorflow.python.keras.layers import deserialize, serialize
-from tensorflow.python.keras.saving import saving_utils
-
-
-def unpack(model, training_config, weights):
-    restored_model = deserialize(model)
-    if training_config is not None:
-        restored_model.compile(
-            **saving_utils.compile_args_from_training_config(
-                training_config
-            )
-        )
-    restored_model.set_weights(weights)
-    return restored_model
-
-# Hotfix function
-def make_keras_picklable():
-
-    def __reduce__(self):
-        model_metadata = saving_utils.model_metadata(self)
-        training_config = model_metadata.get("training_config", None)
-        model = serialize(self)
-        weights = self.get_weights()
-        return (unpack, (model, training_config, weights))
-
-    cls = Model
-    cls.__reduce__ = __reduce__
-
-# Run the function
-make_keras_picklable()
-
-
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)


-
-
env = gym.make("Pong-v0", frameskip=5, difficulty=0)
env.seed(1) # for reproducibility

@@ -78,27 +44,30 @@ def __reduce__(self):
# observation: observation which is fed as input to the model
# Returns:
# action: choice of agent action
-def choose_action(model, observation):
+def choose_action(model, observations):
    # add batch dimension to the observation
-    observation = np.expand_dims(observation, axis=0)
+    # observation = np.expand_dims(observation, axis=0)
    '''TODO: feed the observations through the model to predict the log probabilities of each possible action.'''
-    logits = model.predict(observation) # TODO
+
+    logits = model.predict(observations) # TODO
    # logits = model.predict('''TODO''')

    # pass the log probabilities through a softmax to compute true probabilities
-    prob_weights = tf.nn.softmax(logits).numpy()
+    prob_weights = tf.nn.softmax(logits)
    '''TODO: randomly sample from the prob_weights to pick an action.
    Hint: carefully consider the dimensionality of the input probabilities (vector) and the output action (scalar)'''
-    action = np.random.choice(
-        n_actions, size=1, p=prob_weights.flatten())[0] # TODO
+
+    action = tf.random.categorical(logits, 1)[:,0].numpy()
+
+    # action = np.random.choice(
+    #     n_actions, size=1, p=prob_weights.flatten())[0] # TODO
    # action = np.random.choice('''TODO''', size=1, p=''''TODO''')['''TODO''']

    return action


### Reward function ###

-
# Helper function that normalizes an np.array x
def normalize(x):
    x -= np.mean(x)
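
Aside: the reworked choose_action above now expects a batch of observations and samples one action per environment in a single call. A minimal, self-contained sketch of that sampling pattern follows; sample_actions and dummy_logits are illustrative names, not part of the lab code.

import numpy as np
import tensorflow as tf

def sample_actions(logits_batch):
    # logits_batch has shape (batch_size, n_actions); tf.random.categorical
    # draws one sample per row, returning an int tensor of shape (batch_size, 1).
    samples = tf.random.categorical(logits_batch, num_samples=1)
    return samples[:, 0].numpy()  # shape (batch_size,)

# Example: 4 parallel environments, 3 possible actions each.
dummy_logits = tf.constant(np.random.randn(4, 3), dtype=tf.float32)
print(sample_actions(dummy_logits))  # e.g. [2 0 1 1]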
@@ -109,7 +78,6 @@ def normalize(x):

### Agent Memory ###

-
class Memory:
    def __init__(self):
        self.clear()
@@ -258,7 +226,6 @@ def fix(img):
# Model and optimizer
pong_model = create_pong_model()
pong_model.build((None, 40, 40, 1))
-pong_model.save("model.h5")

optimizer = tf.keras.optimizers.Adam(learning_rate)

@@ -267,100 +234,109 @@ def fix(img):
smoothed_reward.append(-21) # start the reward at the minimum (0-21) for baseline comparison
# plotter = mdl.util.PeriodicPlotter(
#   sec=5, xlabel='Iterations', ylabel='Rewards')
-memory = Memory()
batch_size = args.batch_size
-batches = 0


-def run_episode(env, model):
-    ("running episode")
-    memory = Memory()
-    observation = env.reset()
-    previous_frame = fix(mdl.lab3.preprocess_pong(observation))
-    done = False
-    while not done:
-        # Pre-process image
-        current_frame = fix(mdl.lab3.preprocess_pong(observation))
-        obs_change = current_frame - previous_frame # TODO

-        # obs_change = # TODO
-        action = choose_action(model, obs_change) # TODO

-        # action = # TODO
-        # Take the chosen action
-        next_observation, reward, done, info = env.step(action)

-        memory.add_to_memory(obs_change, action, reward) # TODO

-        observation = next_observation
-        previous_frame = current_frame
-    return memory
+# def run_episode(env, model):
+#     print("running episode")
+#     memory = Memory()
+#     observation = env.reset()
+#     previous_frame = fix(mdl.lab3.preprocess_pong(observation))
+#     done = False
+#     while not done:
+#         # Pre-process image
+#         current_frame = fix(mdl.lab3.preprocess_pong(observation))
+#         obs_change = current_frame - previous_frame # TODO
+#
+#         # obs_change = # TODO
+#         tic = time.time()
+#         action = choose_action(model, obs_change) # TODO
+#
+#         # action = # TODO
+#         # Take the chosen action
+#         tic = time.time()
+#         next_observation, reward, done, info = env.step(action)
+#
+#         memory.add_to_memory(obs_change, action, reward) # TODO
+#
+#         observation = next_observation
+#         previous_frame = current_frame
+#     return memory


+envs = [copy.deepcopy(env) for _ in range(batch_size)]

for i_episode in range(MAX_ITERS):

-    # plotter.plot(smoothed_reward.get())
+    tic = time.time()
+    memories = [Memory() for _ in range(batch_size)]
+    next_observations = [single_env.reset() for single_env in envs]
+    previous_frames = [obs for obs in next_observations]
+    done = [False] * batch_size
+    actions = [0] * batch_size
+    rewards = [0] * batch_size
+    print("reiniting", time.time()-tic)

-    # # Restart the environment
-    # observation = env.reset()
-    # previous_frame = fix(mdl.lab3.preprocess_pong(observation))
-    # tic = time.time()
-    # while True:
-    #     # Pre-process image
-    #     current_frame = fix(mdl.lab3.preprocess_pong(observation))
-    #     '''TODO: determine the observation change
-    #     Hint: this is the difference between the past two frames'''
-    #     obs_change = current_frame - previous_frame # TODO
-    #
-    #
-    #
-    #     # obs_change = # TODO
-    #     '''TODO: choose an action for the pong model, using the frame difference, and evaluate'''
-    #     action = choose_action(pong_model, obs_change) # TODO
-    #     # action = # TODO
-    #     # Take the chosen action
-    #     next_observation, reward, done, info = env.step(action)
-    #     '''TODO: save the observed frame difference, the action that was taken, and the resulting reward!'''
-    #     memory.add_to_memory(obs_change, action, reward) # TODO
-    #
-    #     if len(memory.actions) % 3 == 0 and args.draw:
-    #         z = obs_change
-    #         z = (z-z.min())/ (z.max()-z.min()+1e-6)
-    #         cv2.imshow('hi', cv2.resize(z, (256, 256)))
-    #         cv2.waitKey(1)

+    tic = time.time()
+    while True:
+
+        current_frames = [obs for obs in next_observations]
+        diff_frames = [mdl.lab3.pong_change(prev, curr) for (prev, curr) in zip(previous_frames, current_frames)]
+
+        diff_frames_not_done = [diff_frames[b] for b in range(batch_size) if not done[b]]
+        actions_not_done = choose_action(pong_model, np.array(diff_frames_not_done))
+
+        actions = [None] * batch_size
+        ind_not_done = 0
+        for b in range(batch_size):
+            if not done[b]:
+                actions[b] = actions_not_done[ind_not_done]
+                ind_not_done += 1
+
+        for b in range(batch_size):
+            if done[b]:
+                continue
+            next_observations[b], rewards[b], done[b], info = envs[b].step(actions[b])
+            previous_frames[b] = current_frames[b]
+            memories[b].add_to_memory(diff_frames[b], actions[b], rewards[b])


-    import copy
+        if all(done):
+            break

-    def parallel_episode(new_model):
-        print("insdie paralel")
-        # new_model = tf.keras.models.load_model('model.h5')
-        print(new_model)
-        return run_episode(env=copy.deepcopy(env), model=new_model)

+    # def parallel_episode(i):
+    #     return run_episode(env=copy.deepcopy(env), model=pong_model)
+    #
    # tic = time.time()
    # memories = [parallel_episode(batch) for batch in range(batch_size)]
    # print(time.time()-tic)

-    models = [tf.keras.models.load_model('model.h5') for b in range(batch_size)]
-    tic = time.time()
-    with Pool(processes=batch_size) as pool:
-        memories = pool.map(parallel_episode, models)#range(batch_size))
+    # models = [tf.keras.models.load_model('model.h5') for b in range(batch_size)]
+    # tic = time.time()
+    # with Pool(processes=batch_size) as pool:
+    #     memories = pool.map(parallel_episode, models)#range(batch_size))
+    # print(time.time()-tic)
+
    print(time.time()-tic)

    batch_memory = Memory()
    for memory in memories:
        for step in zip(memory.observations, memory.actions, memory.rewards):
            batch_memory.add_to_memory(*step)

+
+
    def play(memory):
        for o in memory.observations:
            cv2.imshow('hi', cv2.resize(o, (500,500)))
            cv2.waitKey(20)

-    # import pdb; pdb.set_trace()


    ### Train with this batch!!!
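
For reference: the core of the new rollout loop above is stepping only the environments that have not yet finished, so shorter games drop out of the batch while longer ones keep running. A condensed, self-contained sketch of that masking idea follows; step_active, envs, observations, and model are illustrative stand-ins for the names used in the diff.

import numpy as np

def step_active(envs, done, observations, model):
    # Gather observations only from environments that are still running.
    active = [b for b in range(len(envs)) if not done[b]]
    obs_batch = np.array([observations[b] for b in active])

    # One batched call chooses an action for every active environment
    # (stands in for choose_action(pong_model, obs_batch) in the diff).
    actions = model(obs_batch)

    # Step only the active environments; finished ones keep their state.
    for action, b in zip(actions, active):
        observations[b], _, done[b], _ = envs[b].step(action)
    return observations, done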
@@ -380,9 +356,6 @@ def play(memory):
    last_smoothed_reward = smoothed_reward.get()[-1]
    print(f"{iters} \t {round(last_smoothed_reward, 3)}")

-    tf.keras.backend.clear_session()
-    pong_model = tf.keras.models.load_model('model.h5')
-
    # begin training
    train_step(
        pong_model,
@@ -391,9 +364,6 @@ def play(memory):
        actions=np.array(batch_memory.actions),
        discounted_rewards=discount_rewards(batch_memory.rewards))

-    tf.keras.backend.clear_session()
-    del pong_model
-
    batch_memory.clear()
    # break

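
One more illustrative note: the other half of the new loop flattens the per-environment memories into one training batch before a single train_step call, which is presumably also why the model.h5 save/load and clear_session round-trips could be dropped, since the same in-memory pong_model now handles both rollout and training. Below is a self-contained toy version of that merge, with a minimal stand-in for the lab's Memory class (the real class also stores preprocessed frames).

class Memory:
    # Minimal stand-in: parallel lists of per-step data.
    def __init__(self):
        self.observations, self.actions, self.rewards = [], [], []
    def add_to_memory(self, obs, action, reward):
        self.observations.append(obs)
        self.actions.append(action)
        self.rewards.append(reward)

# Two toy episodes of different lengths, as two environments would produce.
m1, m2 = Memory(), Memory()
m1.add_to_memory("frame0", 1, 0.0)
m1.add_to_memory("frame1", 2, 1.0)
m2.add_to_memory("frame0", 0, -1.0)

# Flatten them into one batch, exactly as the training loop does.
batch_memory = Memory()
for memory in (m1, m2):
    for step in zip(memory.observations, memory.actions, memory.rewards):
        batch_memory.add_to_memory(*step)

print(len(batch_memory.actions))  # 3 steps total across both episodes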