-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdeployGamerPlayer.py
More file actions
121 lines (77 loc) · 2.59 KB
/
deployGamerPlayer.py
File metadata and controls
121 lines (77 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
# coding: utf-8
# In[7]:
import gym
import numpy as np
import tensorflow as tf
from tensorflow import keras
import sys
import os
from tf_agents.environments import suite_gym
import matplotlib.pyplot as plt
from tf_agents.environments.wrappers import ActionRepeat
from gym.wrappers import TimeLimit
from tf_agents.environments import suite_atari
from tf_agents.environments.atari_preprocessing import AtariPreprocessing
from tf_agents.environments.atari_wrappers import FrameStack4
from tf_agents.networks.q_network import QNetwork
from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tf_agents.environments import suite_atari
from tf_agents.environments.atari_preprocessing import AtariPreprocessing
from tf_agents.environments.atari_wrappers import FrameStack4
from tf_agents.environments import suite_gym
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver
# In[3]:
# Atari game to play; the NoFrameskip variant leaves frame skipping to the
# preprocessing wrapper below.
environment_name = "SpaceInvadersNoFrameskip-v4"
# Upper bound on environment steps per episode, enforced by suite_atari.load.
max_episode_steps = 27000
env = suite_atari.load(
    environment_name,
    max_episode_steps=max_episode_steps,
    # Preprocess frames, then stack the last four observations.
    gym_env_wrappers=[AtariPreprocessing, FrameStack4],
)
# In[4]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment
# Expose the Python environment through the TensorFlow environment API that
# drivers and TF policies consume.
tf_env = TFPyEnvironment(env)
# In[5]:
# Load a TF SavedModel from the relative directory 'savedPolicy' (resolved
# against the current working directory).
# NOTE(review): presumably a policy exported by the training script — confirm.
saved_policy = tf.compat.v2.saved_model.load('savedPolicy')
# In[ ]:
# In[11]:
def update_scene(num, frames, patch):
    """Animation callback: draw frame number ``num`` of ``frames`` onto ``patch``.

    Returns a one-element tuple holding ``patch``, the only artist modified.
    """
    frame = frames[num]
    patch.set_data(frame)
    return (patch,)
def plot_animation(frames, repeat=False, interval=40):
    """Build a matplotlib animation that plays back a sequence of images.

    Args:
        frames: Non-empty sequence of images accepted by ``plt.imshow``
            (in this file, RGB arrays rendered from the environment).
        repeat: Whether the animation loops when it reaches the last frame.
        interval: Delay between frames, in milliseconds.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object. Keep a reference
        to it (or display/save it); otherwise it will not play.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # Imported locally so this function does not depend on the module-level
    # ``animation`` import, which appears later in the file.
    import matplotlib.animation as animation

    # len() rather than truthiness: ``frames`` may be a NumPy array.
    if len(frames) == 0:
        raise ValueError("plot_animation() needs at least one frame")
    fig = plt.figure()
    patch = plt.imshow(frames[0])
    plt.axis('off')
    anim = animation.FuncAnimation(
        fig, update_scene, fargs=(frames, patch),
        frames=len(frames), repeat=repeat, interval=interval)
    # Close the figure so an inline notebook backend does not also show the
    # static first frame; the returned animation still references ``fig``.
    plt.close(fig)
    return anim
# In[12]:
class ShowProgress:
    """Driver observer that prints an in-place ``counter/total`` progress line.

    Boundary trajectories are not counted. A line is printed each time the
    counter advances to a multiple of ``every``. Each line starts with a
    carriage return and has no trailing newline, so it overwrites the
    previous one in a terminal.
    """

    def __init__(self, total, every=100):
        """
        Args:
            total: Expected number of counted steps, shown as the denominator.
            every: Report progress every ``every`` counted steps (positive int).
        """
        self.counter = 0
        self.total = total
        self.every = every

    def __call__(self, trajectory):
        """Count ``trajectory`` unless it is a boundary, then maybe report."""
        if trajectory.is_boundary():
            # Boundary steps leave the counter unchanged. Re-printing the same
            # value (e.g. "0/N" before the first counted step) is just noise.
            return
        self.counter += 1
        if self.counter % self.every == 0:
            print("\r{}/{}".format(self.counter, self.total), end="")
# In[13]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Render animations in notebooks as interactive JavaScript/HTML widgets.
mpl.rc('animation', html='jshtml')

# Number of environment steps to run the saved policy for.
num_watch_steps = 1000

# Frames captured by save_frames(), one rendered RGB array per driver step.
frames = []


def save_frames(trajectory):
    """Driver observer: render the current environment frame and store it.

    ``trajectory`` is unused but required by the observer call signature.
    ``frames`` is only appended to, never rebound, so no ``global`` is needed.
    """
    # Render from the first Python environment wrapped by tf_env.
    frames.append(tf_env.pyenv.envs[0].render(mode="rgb_array"))


watch_driver = DynamicStepDriver(
    tf_env,
    saved_policy,
    observers=[save_frames, ShowProgress(num_watch_steps)],
    num_steps=num_watch_steps)
final_time_step, final_policy_state = watch_driver.run()
print()  # terminate the in-place progress line printed by ShowProgress
# NOTE(review): in a notebook this last expression renders the animation as
# the cell output. Run as a plain script, the returned animation is discarded
# and nothing is shown; to keep it, consider e.g.
# plot_animation(frames).save("gameplay.gif", writer="pillow").
plot_animation(frames)
# In[ ]: