1
+ {
2
+ "nbformat" : 4 ,
3
+ "nbformat_minor" : 0 ,
4
+ "metadata" : {
5
+ "colab" : {
6
+ "name" : " Colab-UnityEnvironment-1-Run.ipynb" ,
7
+ "private_outputs" : true ,
8
+ "provenance" : [],
9
+ "collapsed_sections" : [],
10
+ "toc_visible" : true
11
+ },
12
+ "kernelspec" : {
13
+ "name" : " python3" ,
14
+ "language" : " python" ,
15
+ "display_name" : " Python 3"
16
+ }
17
+ },
18
+ "cells" : [
19
+ {
20
+ "cell_type" : " markdown" ,
21
+ "metadata" : {
22
+ "id" : " pbVXrmEsLXDt"
23
+ },
24
+ "source" : [
25
+ " # ML-Agents run with Stable Baselines 3\n " ,
26
+ " <img src=\" https://github.com/Unity-Technologies/ml-agents/blob/release_19_docs/docs/images/image-banner.png?raw=true\" align=\" middle\" width=\" 435\" />"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type" : " markdown" ,
31
+ "metadata" : {
32
+ "id" : " WNKTwHU3d2-l"
33
+ },
34
+ "source" : [
35
+ " ## Setup"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type" : " code" ,
40
+ "execution_count" : null ,
41
+ "outputs" : [],
42
+ "source" : [
43
+ " #@title Install Rendering Dependencies { display-mode: \" form\" }\n " ,
44
+ " #@markdown (You only need to run this code when using Colab's hosted runtime)\n " ,
45
+ " \n " ,
46
+ " import os\n " ,
47
+ " from IPython.display import HTML, display\n " ,
48
+ " \n " ,
49
+ " def progress(value, max=100):\n " ,
50
+ " return HTML(\"\"\"\n " ,
51
+ " <progress\n " ,
52
+ " value='{value}'\n " ,
53
+ " max='{max}',\n " ,
54
+ " style='width: 100%'\n " ,
55
+ " >\n " ,
56
+ " {value}\n " ,
57
+ " </progress>\n " ,
58
+ " \"\"\" .format(value=value, max=max))\n " ,
59
+ " \n " ,
60
+ " pro_bar = display(progress(0, 100), display_id=True)\n " ,
61
+ " \n " ,
62
+ " try:\n " ,
63
+ " import google.colab\n " ,
64
+ " INSTALL_XVFB = True\n " ,
65
+ " except ImportError:\n " ,
66
+ " INSTALL_XVFB = 'COLAB_ALWAYS_INSTALL_XVFB' in os.environ\n " ,
67
+ " \n " ,
68
+ " if INSTALL_XVFB:\n " ,
69
+ " with open('frame-buffer', 'w') as writefile:\n " ,
70
+ " writefile.write(\"\"\" #taken from https://gist.github.com/jterrace/2911875\n " ,
71
+ " XVFB=/usr/bin/Xvfb\n " ,
72
+ " XVFBARGS=\" :1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n " ,
73
+ " PIDFILE=./frame-buffer.pid\n " ,
74
+ " case \" $1\" in\n " ,
75
+ " start)\n " ,
76
+ " echo -n \" Starting virtual X frame buffer: Xvfb\"\n " ,
77
+ " /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n " ,
78
+ " echo \" .\"\n " ,
79
+ " ;;\n " ,
80
+ " stop)\n " ,
81
+ " echo -n \" Stopping virtual X frame buffer: Xvfb\"\n " ,
82
+ " /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n " ,
83
+ " rm $PIDFILE\n " ,
84
+ " echo \" .\"\n " ,
85
+ " ;;\n " ,
86
+ " restart)\n " ,
87
+ " $0 stop\n " ,
88
+ " $0 start\n " ,
89
+ " ;;\n " ,
90
+ " *)\n " ,
91
+ " echo \" Usage: /etc/init.d/xvfb {start|stop|restart}\"\n " ,
92
+ " exit 1\n " ,
93
+ " esac\n " ,
94
+ " exit 0\n " ,
95
+ " \"\"\" )\n " ,
96
+ " !sudo apt-get update\n " ,
97
+ " pro_bar.update(progress(10, 100))\n " ,
98
+ " !sudo DEBIAN_FRONTEND=noninteractive apt install -y daemon wget gdebi-core build-essential libfontenc1 libfreetype6 xorg-dev xorg\n " ,
99
+ " pro_bar.update(progress(20, 100))\n " ,
100
+ " !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1\n " ,
101
+ " pro_bar.update(progress(30, 100))\n " ,
102
+ " !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb 2>&1\n " ,
103
+ " pro_bar.update(progress(40, 100))\n " ,
104
+ " !sudo dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1\n " ,
105
+ " pro_bar.update(progress(50, 100))\n " ,
106
+ " !sudo dpkg -i xvfb.deb 2>&1\n " ,
107
+ " pro_bar.update(progress(70, 100))\n " ,
108
+ " !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n " ,
109
+ " pro_bar.update(progress(80, 100))\n " ,
110
+ " !rm xvfb.deb\n " ,
111
+ " pro_bar.update(progress(90, 100))\n " ,
112
+ " !bash frame-buffer start\n " ,
113
+ " os.environ[\" DISPLAY\" ] = \" :1\"\n " ,
114
+ " pro_bar.update(progress(100, 100))"
115
+ ],
116
+ "metadata" : {
117
+ "collapsed" : false ,
118
+ "pycharm" : {
119
+ "name" : " #%%\n "
120
+ }
121
+ }
122
+ },
123
+ {
124
+ "cell_type" : " markdown" ,
125
+ "metadata" : {
126
+ "id" : " Pzj7wgapAcDs"
127
+ },
128
+ "source" : [
129
+ " ### Installing ml-agents"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type" : " code" ,
134
+ "metadata" : {
135
+ "id" : " N8yfQqkbebQ5" ,
136
+ "pycharm" : {
137
+ "is_executing" : true
138
+ }
139
+ },
140
+ "source" : [
141
+ " try:\n " ,
142
+ " import mlagents\n " ,
143
+ " print(\" ml-agents already installed\" )\n " ,
144
+ " except ImportError:\n " ,
145
+ " !python -m pip install -q mlagents==0.28.0\n " ,
146
+ " print(\" Installed ml-agents\" )"
147
+ ],
148
+ "execution_count" : null ,
149
+ "outputs" : []
150
+ },
151
+ {
152
+ "cell_type" : " markdown" ,
153
+ "metadata" : {
154
+ "id" : " _u74YhSmW6gD"
155
+ },
156
+ "source" : [
157
+ " ## Run the Environment"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type" : " markdown" ,
162
+ "metadata" : {
163
+ "id" : " P-r_cB2rqp5x"
164
+ },
165
+ "source" : [
166
+ " ### Import dependencies and set some high level parameters."
167
+ ]
168
+ },
169
+ {
170
+ "cell_type" : " code" ,
171
+ "metadata" : {
172
+ "id" : " YSf-WhxbqtLw"
173
+ },
174
+ "source" : [
175
+ " from math import ceil\n " ,
176
+ " \n " ,
177
+ " from stable_baselines3 import PPO\n " ,
178
+ " from stable_baselines3.common.vec_env import VecMonitor\n " ,
179
+ " \n " ,
180
+ " from mlagents_envs.envs.unity_vec_env import make_mla_sb3_env, LimitedConfig\n " ,
181
+ " \n " ,
182
+ " # 250K should train to a reward ~= 0.90 for the \" Basic\" environment.\n " ,
183
+ " # We set the value lower here to demonstrate just a small amount of training.\n " ,
184
+ " TOTAL_TAINING_STEPS_GOAL = 40 * 1000\n " ,
185
+ " NUM_ENVS = 12\n " ,
186
+ " STEPS_PER_UPDATE = 2048"
187
+ ],
188
+ "execution_count" : 29 ,
189
+ "outputs" : []
190
+ },
191
+ {
192
+ "cell_type" : " markdown" ,
193
+ "source" : [
194
+ " ### Start Environment from the registry"
195
+ ],
196
+ "metadata" : {
197
+ "collapsed" : false
198
+ }
199
+ },
200
+ {
201
+ "cell_type" : " code" ,
202
+ "execution_count" : null ,
203
+ "outputs" : [],
204
+ "source" : [
205
+ " # -----------------\n " ,
206
+ " # This code is used to close an env that might not have been closed before\n " ,
207
+ " try:\n " ,
208
+ " env.close()\n " ,
209
+ " except:\n " ,
210
+ " pass\n " ,
211
+ " # -----------------\n " ,
212
+ " \n " ,
213
+ " env = make_mla_sb3_env(\n " ,
214
+ " config=LimitedConfig(\n " ,
215
+ " env_path_or_name='Basic', # Can use any name from a registry or a path to your own unity build.\n " ,
216
+ " base_port=6006,\n " ,
217
+ " base_seed=42,\n " ,
218
+ " num_env=NUM_ENVS,\n " ,
219
+ " allow_multiple_obs=True,\n " ,
220
+ " ),\n " ,
221
+ " no_graphics=True, # Set to false if you are running locally and want to watch the environments move around as they train.\n " ,
222
+ " )"
223
+ ],
224
+ "metadata" : {
225
+ "collapsed" : false ,
226
+ "pycharm" : {
227
+ "name" : " #%%\n " ,
228
+ "is_executing" : true
229
+ }
230
+ }
231
+ },
232
+ {
233
+ "cell_type" : " markdown" ,
234
+ "source" : [
235
+ " ### Create the model"
236
+ ],
237
+ "metadata" : {
238
+ "collapsed" : false
239
+ }
240
+ },
241
+ {
242
+ "cell_type" : " code" ,
243
+ "execution_count" : null ,
244
+ "outputs" : [],
245
+ "source" : [
246
+ " # Helps gather stats for our eval() calls later so we can see reward stats.\n " ,
247
+ " env = VecMonitor(env)\n " ,
248
+ " # Attempt to approximate settings from 3DBall.yaml\n " ,
249
+ " model = PPO(\n " ,
250
+ " \" MlpPolicy\" ,\n " ,
251
+ " env,\n " ,
252
+ " verbose=1,\n " ,
253
+ " learning_rate=lambda prog: 0.0003 * (1.0 - prog),\n " ,
254
+ " # Uncomment this if you want to log tensorboard results when running this notebook locally.\n " ,
255
+ " # tensorboard_log=\" results\" ,\n " ,
256
+ " n_steps=int(STEPS_PER_UPDATE),\n " ,
257
+ " )"
258
+ ],
259
+ "metadata" : {
260
+ "collapsed" : false ,
261
+ "pycharm" : {
262
+ "name" : " #%%\n " ,
263
+ "is_executing" : true
264
+ }
265
+ }
266
+ },
267
+ {
268
+ "cell_type" : " markdown" ,
269
+ "source" : [
270
+ " ### Train the model"
271
+ ],
272
+ "metadata" : {
273
+ "collapsed" : false
274
+ }
275
+ },
276
+ {
277
+ "cell_type" : " code" ,
278
+ "execution_count" : null ,
279
+ "outputs" : [],
280
+ "source" : [
281
+ " training_rounds = ceil(TOTAL_TAINING_STEPS_GOAL / int(STEPS_PER_UPDATE * NUM_ENVS))\n " ,
282
+ " for i in range(training_rounds):\n " ,
283
+ " print(f\" Training round {i + 1}/{training_rounds}\" )\n " ,
284
+ " # NOTE: reset_num_timesteps should only be True the first time so that tensorboard logs are consistent.\n " ,
285
+ " model.learn(total_timesteps=6000, reset_num_timesteps=(i == 0))\n " ,
286
+ " model.policy.eval()"
287
+ ],
288
+ "metadata" : {
289
+ "collapsed" : false ,
290
+ "pycharm" : {
291
+ "name" : " #%%\n " ,
292
+ "is_executing" : true
293
+ }
294
+ }
295
+ },
296
+ {
297
+ "cell_type" : " markdown" ,
298
+ "metadata" : {
299
+ "id" : " h1lIx3_l24OP"
300
+ },
301
+ "source" : [
302
+ " ### Close the environment\n " ,
303
+ " Frees up the ports being used."
304
+ ]
305
+ },
306
+ {
307
+ "cell_type" : " code" ,
308
+ "metadata" : {
309
+ "id" : " vdWG6_SqtNtv" ,
310
+ "pycharm" : {
311
+ "is_executing" : true ,
312
+ "name" : " #%%\n "
313
+ }
314
+ },
315
+ "source" : [
316
+ " env.close()\n " ,
317
+ " print(\" Closed environment\" )\n "
318
+ ],
319
+ "execution_count" : null ,
320
+ "outputs" : []
321
+ }
322
+ ]
323
+ }
0 commit comments