Skip to content

Commit 8279509

Browse files
committed
finalizing lab3 2021
1 parent e055a14 commit 8279509

File tree

1 file changed

+59
-36
lines changed

1 file changed

+59
-36
lines changed

lab3/solutions/RL_Solution.ipynb

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"nbformat_minor": 0,
44
"metadata": {
55
"colab": {
6-
"name": "RL_Solution_APS.ipynb",
6+
"name": "RL_Solution.ipynb",
77
"provenance": [],
88
"collapsed_sections": []
99
},
@@ -109,7 +109,7 @@
109109
"import base64, io, time, gym\n",
110110
"import IPython, functools\n",
111111
"import matplotlib.pyplot as plt\n",
112-
"import copy\n",
112+
"import time\n",
113113
"from tqdm import tqdm\n",
114114
"\n",
115115
"# !pip install mitdeeplearning\n",
@@ -624,7 +624,9 @@
624624
"id": "lbYHLr66i15n"
625625
},
626626
"source": [
627-
"env = gym.make(\"Pong-v0\", frameskip=5)\n",
627+
"def create_pong_env(): \n",
628+
" return gym.make(\"Pong-v0\", frameskip=5)\n",
629+
"env = create_pong_env()\n",
628630
"env.seed(1); # for reproducibility"
629631
],
630632
"execution_count": null,
@@ -819,7 +821,9 @@
819821
"id": "YBLVfdpv7ajG"
820822
},
821823
"source": [
822-
"Let's also consider the fact that, unlike CartPole, the Pong environment is a *dynamic* one -- that is, the environment is changing over time, based on the actions we take and the actions of the opponent, which result in motion of the ball and motion of the paddles. Therefore, to capture the dynamics, we also consider how the environment changes by looking at the difference between a previous observation (image frame) and the current observation (image frame). We've implemented a helper function, `pong_change`, that pre-processes two frames, calculates the change between the two, and then re-normalizes the values. Let's inspect this to visualize how the environment can change:"
824+
"Let's also consider the fact that, unlike CartPole, the Pong environment has an additional element of uncertainty -- regardless of what action the agent takes, we don't know how the opponent will play. That is, the environment is changing over time, based on *both* the actions we take and the actions of the opponent, which result in motion of the ball and motion of the paddles.\r\n",
825+
"\r\n",
826+
"Therefore, to capture the dynamics, we also consider how the environment changes by looking at the difference between a previous observation (image frame) and the current observation (image frame). We've implemented a helper function, `pong_change`, that pre-processes two frames, calculates the change between the two, and then re-normalizes the values. Let's inspect this to visualize how the environment can change:"
823827
]
824828
},
825829
{
@@ -837,7 +841,7 @@
837841
" a.axis(\"off\")\r\n",
838842
"ax[0].imshow(observation); ax[0].set_title('Previous Frame');\r\n",
839843
"ax[1].imshow(next_observation); ax[1].set_title('Current Frame');\r\n",
840-
"ax[2].imshow(np.squeeze(diff)); ax[2].set_title('Difference');"
844+
"ax[2].imshow(np.squeeze(diff)); ax[2].set_title('Difference (Model Input)');"
841845
],
842846
"execution_count": null,
843847
"outputs": []
@@ -986,55 +990,72 @@
986990
"Let's run the code block to train our Pong agent. Note that, even with parallelization, completing training and getting stable behavior will take quite a bit of time (estimated at least a couple of hours). We will again visualize the evolution of the total reward as a function of training to get a sense of how the agent is learning."
987991
]
988992
},
993+
{
994+
"cell_type": "code",
995+
"metadata": {
996+
"id": "FaEHTMRVMRXP"
997+
},
998+
"source": [
999+
"### Hyperparameters and setup for training ###\r\n",
1000+
"# Rerun this cell if you want to re-initialize the training process\r\n",
1001+
"# (i.e., create new model, reset loss, etc)\r\n",
1002+
"\r\n",
1003+
"# Hyperparameters\r\n",
1004+
"learning_rate = 1e-3\r\n",
1005+
"MAX_ITERS = 1000 # increase the maximum to train longer\r\n",
1006+
"batch_size = 5 # number of batches to run\r\n",
1007+
"\r\n",
1008+
"# Model, optimizer\r\n",
1009+
"pong_model = create_pong_model()\r\n",
1010+
"optimizer = tf.keras.optimizers.Adam(learning_rate)\r\n",
1011+
"iteration = 0 # counter for training steps\r\n",
1012+
"\r\n",
1013+
"# Plotting\r\n",
1014+
"smoothed_reward = mdl.util.LossHistory(smoothing_factor=0.9)\r\n",
1015+
"smoothed_reward.append(0) # start the reward at zero for baseline comparison\r\n",
1016+
"plotter = mdl.util.PeriodicPlotter(sec=15, xlabel='Iterations', ylabel='Win Percentage (%)')\r\n",
1017+
"\r\n",
1018+
"# Batches and environment\r\n",
1019+
"# To parallelize batches, we need to make multiple copies of the environment.\r\n",
1020+
"envs = [create_pong_env() for _ in range(batch_size)] # For parallelization"
1021+
],
1022+
"execution_count": null,
1023+
"outputs": []
1024+
},
9891025
{
9901026
"cell_type": "code",
9911027
"metadata": {
9921028
"id": "xCwyQQrPnkZG"
9931029
},
9941030
"source": [
9951031
"### Training Pong ###\n",
996-
"\n",
997-
"# Hyperparameters\n",
998-
"learning_rate = 1e-3\n",
999-
"MAX_ITERS = 1000 # increase the maximum to train longer\n",
1032+
"# You can run this cell and stop it anytime in the middle of training to save \n",
1033+
"# a progress video (see next codeblock). To continue training, simply run this \n",
1034+
"# cell again, your model will pick up right where it left off. To reset training,\n",
1035+
"# you need to run the cell above. \n",
10001036
"\n",
10011037
"games_to_win_episode = 21 # this is set by OpenAI gym and cannot be changed.\n",
10021038
"\n",
1003-
"# Model, optimizer\n",
1004-
"pong_model = create_pong_model()\n",
1005-
"optimizer = tf.keras.optimizers.Adam(learning_rate)\n",
1006-
"\n",
1007-
"# Plotting\n",
1008-
"smoothed_reward = mdl.util.LossHistory(smoothing_factor=0.9)\n",
1009-
"smoothed_reward.append(0) # start the reward at zero for baseline comparison\n",
1010-
"plotter = mdl.util.PeriodicPlotter(sec=20, xlabel='Iterations', ylabel='Win Percentage (%)')\n",
1011-
"\n",
1012-
"# Batches and environment\n",
1013-
"batch_size = 5 # number of batches to run\n",
1014-
"# To parallelize batches, we need to make multiple copies of the environment.\n",
1015-
"envs = [copy.deepcopy(env) for _ in range(batch_size)] # For parallelization\n",
1016-
"\n",
10171039
"# Main training loop\n",
1018-
"for i_episode in range(MAX_ITERS):\n",
1040+
"while iteration < MAX_ITERS:\n",
10191041
"\n",
10201042
" plotter.plot(smoothed_reward.get())\n",
10211043
"\n",
1044+
" tic = time.time()\n",
10221045
" # RL agent algorithm. By default, uses serial batch processing.\n",
10231046
" # memories = collect_rollout(batch_size, env, pong_model, choose_action)\n",
10241047
"\n",
10251048
" # Parallelized version. Uncomment line below (and comment out line above) to parallelize\n",
10261049
" memories = mdl.lab3.parallelized_collect_rollout(batch_size, envs, pong_model, choose_action)\n",
1050+
" print(time.time()-tic)\n",
10271051
"\n",
10281052
" # Aggregate memories from multiple batches\n",
10291053
" batch_memory = aggregate_memories(memories)\n",
10301054
"\n",
1031-
" # Determine total reward and track reported as win percentage\n",
1032-
" # net_score = sum(batch_memory.rewards) / batch_size\n",
1033-
" # win_rate = abs(net_score) / games_to_win_episode\n",
1055+
" # Track performance based on win percentage (calculated from rewards)\n",
10341056
" total_wins = sum(np.array(batch_memory.rewards) == 1)\n",
10351057
" total_games = sum(np.abs(np.array(batch_memory.rewards)))\n",
10361058
" win_rate = total_wins / total_games\n",
1037-
"\n",
10381059
" smoothed_reward.append(100 * win_rate)\n",
10391060
" \n",
10401061
" # Training!\n",
@@ -1047,9 +1068,11 @@
10471068
" )\n",
10481069
"\n",
10491070
" # Save a video of progress -- this can be played back later\n",
1050-
" if i_episode % 500 == 0:\n",
1071+
" if iteration % 100 == 0:\n",
10511072
" mdl.lab3.save_video_of_model(pong_model, \"Pong-v0\", \n",
1052-
" suffix=\"_\"+str(i_episode))\n"
1073+
" suffix=\"_\"+str(iteration))\n",
1074+
" \n",
1075+
" iteration += 1 # Mark next episode"
10531076
],
10541077
"execution_count": null,
10551078
"outputs": []
@@ -1069,9 +1092,9 @@
10691092
"id": "TvHXbkL0tR6M"
10701093
},
10711094
"source": [
1072-
"final_pong = mdl.lab3.save_video_of_model(\n",
1073-
" pong_model, \"Pong-v0\", suffix=\"final\")\n",
1074-
"mdl.lab3.play_video(final_pong)"
1095+
"latest_pong = mdl.lab3.save_video_of_model(\n",
1096+
" pong_model, \"Pong-v0\", suffix=\"latest\")\n",
1097+
"mdl.lab3.play_video(latest_pong)"
10751098
],
10761099
"execution_count": null,
10771100
"outputs": []
@@ -1088,11 +1111,11 @@
10881111
"\n",
10891112
"* How does the agent perform?\n",
10901113
"* Could you train it for shorter amounts of time and still perform well?\n",
1091-
"* Do you think that training longer would help even more? \n",
1092-
"* How does the complexity of Pong relative to Cartpole alter the rate at which the agent learns and its performance? \n",
1114+
"* What are some limitations of the current representation i.e., difference of current and previous frames? How is this reflected in the agent's behavior? What could be done to generate an improved representation?\n",
1115+
"* How does the complexity of Pong relative to CartPole alter the rate at which the agent learns and its performance? \n",
10931116
"* What are some things you could change about the agent or the learning process to potentially improve performance?\n",
10941117
"\n",
1095-
"Try to optimize your model to achieve improved performance. **MIT students and affiliates will be eligible for prizes during the IAP offering.** To enter the competition, please [email us](mailto:[email protected]) with your name and the following:\n",
1118+
"Try to optimize your **Pong** model and algorithm to achieve improved performance. **MIT students and affiliates will be eligible for prizes during the IAP offering.** To enter the competition, please [email us](mailto:[email protected]) with your name and the following:\n",
10961119
"\n",
10971120
"* Jupyter notebook with the code you used to generate your results, **with the Pong training executed**;\n",
10981121
"* saved video of your Pong agent competing;\n",

0 commit comments

Comments
 (0)