45
45
# Returns:
46
46
# action: choice of agent action
47
47
def choose_action (model , observation , single = True ):
48
- if single : # create a batch dimension if only a single example was provided
49
- observations = np .expand_dims (observation , axis = 0 )
48
+ # create a batch dimension if only a single example was provided
49
+ observations = np .expand_dims (observation , axis = 0 ) if single else observation
50
50
51
51
# add batch dimension to the observation
52
52
# observation = np.expand_dims(observation, axis=0)
@@ -303,8 +303,8 @@ def collect_rollout(batch_size, env, model, choose_action):
303
303
return memories
304
304
305
305
306
- mdl .lab3 .save_video_of_memory (memory [0 ])
307
- collect_rollout (batch_size , env , model , choose_action )
306
+ # mdl.lab3.save_video_of_memory(memory[0])
307
+ # collect_rollout(batch_size, env, model, choose_action)
308
308
309
309
310
310
@@ -315,8 +315,8 @@ def collect_rollout(batch_size, env, model, choose_action):
315
315
316
316
317
317
tic = time .time ()
318
- memories = collect_rollout (batch_size , env , pong_model , choose_action )
319
- # memories = parallelized_collect_rollout(batch_size, envs, pong_model, choose_action)
318
+ # memories = collect_rollout(batch_size, env, pong_model, choose_action)
319
+ memories = parallelized_collect_rollout (batch_size , envs , pong_model , choose_action )
320
320
batch_memory = aggregate_memories (memories )
321
321
print (time .time ()- tic )
322
322
@@ -370,4 +370,4 @@ def collect_rollout(batch_size, env, model, choose_action):
370
370
discounted_rewards = discount_rewards (batch_memory .rewards ))
371
371
372
372
if i_episode % 500 == 0 :
373
- mdl .save_video_of_model (pong_model , "Pong-v0" , suffix = str (i_episode ))
373
+ mdl .lab3 . save_video_of_model (pong_model , "Pong-v0" , suffix = str (i_episode ))
0 commit comments