Skip to content

Commit a7abff4

Browse files
committed
save vid lab3
1 parent e1aeb24 commit a7abff4

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

mitdeeplearning/lab3.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import base64
33
from IPython.display import HTML
4-
4+
import gym
55

66
def play_video(filename):
77
encoded = base64.b64encode(io.open(filename, 'r+b').read())
@@ -19,3 +19,42 @@ def preprocess_pong(image):
1919
I[I == 109] = 0 # Remove background type 2
2020
I[I != 0] = 1 # Set remaining elements (paddles, ball, etc.) to 1
2121
return I.astype(np.float).reshape(80, 80, 1)
22+
23+
24+
def save_video_of_model(model, env_name, obs_diff=False, pp_fn=None):
25+
import skvideo.io
26+
from pyvirtualdisplay import Display
27+
display = Display(visible=0, size=(400, 300))
28+
display.start()
29+
30+
if pp_fn is None:
31+
pp_fn = lambda x: x
32+
33+
env = gym.make(env_name)
34+
obs = env.reset()
35+
obs = pp_fn(obs)
36+
prev_obs = obs
37+
38+
filename = env_name + ".mp4"
39+
output_video = skvideo.io.FFmpegWriter(filename)
40+
41+
counter = 0
42+
done = False
43+
while not done:
44+
frame = env.render(mode='rgb_array')
45+
output_video.writeFrame(frame)
46+
47+
if obs_diff:
48+
input_obs = obs - prev_obs
49+
else:
50+
input_obs = obs
51+
action = model(np.expand_dims(input_obs, 0)).numpy().argmax()
52+
53+
prev_obs = obs
54+
obs, reward, done, info = env.step(action)
55+
obs = pp_fn(obs)
56+
counter += 1
57+
58+
output_video.close()
59+
print("Successfully saved {} frames into {}!".format(counter, filename))
60+
return filename

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ def get_dist(pkgname):
2121
setup(
2222
name = 'mitdeeplearning', # How you named your package folder (MyLib)
2323
packages = ['mitdeeplearning'], # Chose the same as "name"
24-
version = '0.5.1', # Start with a small number and increase it with every change you make
24+
version = '0.5.2', # Start with a small number and increase it with every change you make
2525
license='MIT', # Chose a license from here: https://help.github.com/articles/licensing-a-repository
2626
description = 'Official software labs for MIT Introduction to Deep Learning (http://introtodeeplearning.com)', # Give a short description about your library
2727
author = 'Alexander Amini', # Type in your name
2828
author_email = '[email protected]', # Type in your E-Mail
2929
url = 'http://introtodeeplearning.com', # Provide either the link to your github or to your website
30-
download_url = 'https://github.com/aamini/introtodeeplearning_labs/archive/v0.5.1.tar.gz', # I explain this later on
30+
download_url = 'https://github.com/aamini/introtodeeplearning_labs/archive/v0.5.2.tar.gz', # I explain this later on
3131
keywords = ['deep learning', 'neural networks', 'tensorflow', 'introduction'], # Keywords that define your package best
3232
install_requires=install_deps,
3333
classifiers=[

0 commit comments

Comments
 (0)