Skip to content

Commit 86b0b5b

Browse files
committed
typo
1 parent ef793fd commit 86b0b5b

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

lab3/solutions/pong.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@
4545
# Returns:
4646
# action: choice of agent action
4747
def choose_action(model, observation, single=True):
48-
if single: # create a batch dimension if only a single example was provided
49-
observations = np.expand_dims(observation, axis=0)
48+
# create a batch dimension if only a single example was provided
49+
observations = np.expand_dims(observation, axis=0) if single else observation
5050

5151
# add batch dimension to the observation
5252
# observation = np.expand_dims(observation, axis=0)
@@ -303,8 +303,8 @@ def collect_rollout(batch_size, env, model, choose_action):
303303
return memories
304304

305305

306-
mdl.lab3.save_video_of_memory(memory[0])
307-
collect_rollout(batch_size, env, model, choose_action)
306+
# mdl.lab3.save_video_of_memory(memory[0])
307+
# collect_rollout(batch_size, env, model, choose_action)
308308

309309

310310

@@ -315,8 +315,8 @@ def collect_rollout(batch_size, env, model, choose_action):
315315

316316

317317
tic = time.time()
318-
memories = collect_rollout(batch_size, env, pong_model, choose_action)
319-
# memories = parallelized_collect_rollout(batch_size, envs, pong_model, choose_action)
318+
# memories = collect_rollout(batch_size, env, pong_model, choose_action)
319+
memories = parallelized_collect_rollout(batch_size, envs, pong_model, choose_action)
320320
batch_memory = aggregate_memories(memories)
321321
print(time.time()-tic)
322322

@@ -370,4 +370,4 @@ def collect_rollout(batch_size, env, model, choose_action):
370370
discounted_rewards=discount_rewards(batch_memory.rewards))
371371

372372
if i_episode % 500 == 0:
373-
mdl.save_video_of_model(pong_model, "Pong-v0", suffix=str(i_episode))
373+
mdl.lab3.save_video_of_model(pong_model, "Pong-v0", suffix=str(i_episode))

0 commit comments

Comments
 (0)