Skip to content

Commit ef793fd

Browse files
Alexander Amini authored and Alexander Amini committed
saving cleanup
1 parent 33cd8e3 commit ef793fd

File tree

2 files changed

10 additions and 15 deletions

2 files changed

10 additions and 15 deletions

lab3/solutions/pong.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,5 @@ def collect_rollout(batch_size, env, model, choose_action):
369369
actions=np.array(batch_memory.actions),
370370
discounted_rewards=discount_rewards(batch_memory.rewards))
371371

372-
batch_memory.clear()
373-
# break
374-
375-
# observation = next_observation
376-
# previous_frame = current_frame
372+
if i_episode % 500 == 0:
373+
mdl.save_video_of_model(pong_model, "Pong-v0", suffix=str(i_episode))

mitdeeplearning/lab3.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,17 @@ def pong_change(prev, curr):
4444

4545

4646

47-
def save_video_of_model(model, env_name, obs_diff=False, pp_fn=None):
47+
def save_video_of_model(model, env_name, suffix=""):
4848
import skvideo.io
4949
from pyvirtualdisplay import Display
5050
display = Display(visible=0, size=(400, 300))
5151
display.start()
5252

53-
if pp_fn is None:
54-
pp_fn = lambda x: x
55-
5653
env = gym.make(env_name)
5754
obs = env.reset()
58-
obs = pp_fn(obs)
5955
prev_obs = obs
6056

61-
filename = env_name + ".mp4"
57+
filename = env_name + suffix + ".mp4"
6258
output_video = skvideo.io.FFmpegWriter(filename)
6359

6460
counter = 0
@@ -67,15 +63,17 @@ def save_video_of_model(model, env_name, obs_diff=False, pp_fn=None):
6763
frame = env.render(mode='rgb_array')
6864
output_video.writeFrame(frame)
6965

70-
if obs_diff:
71-
input_obs = obs - prev_obs
72-
else:
66+
if "Cartpole" in env_name:
7367
input_obs = obs
68+
elif "Pong" in env_name:
69+
input_obs = pong_change(prev_obs, obs)
70+
else:
71+
raise ValueError(f"Unknown env for saving: {env_name}")
72+
7473
action = model(np.expand_dims(input_obs, 0)).numpy().argmax()
7574

7675
prev_obs = obs
7776
obs, reward, done, info = env.step(action)
78-
obs = pp_fn(obs)
7977
counter += 1
8078

8179
output_video.close()

0 commit comments

Comments
 (0)