Commit fc74bb8

chr0nikler authored
Fixup for 'Training An Agent' page (#1281)
Co-authored-by: chr0nikler <jchahal@diffzero.com>
1 parent 87cc458 commit fc74bb8

File tree: 1 file changed (+33, −22 lines)


docs/introduction/train_agent.md

Lines changed: 33 additions & 22 deletions
````diff
@@ -155,37 +155,48 @@ You can use `matplotlib` to visualize the training reward and length.
 
 ```python
 from matplotlib import pyplot as plt
-# visualize the episode rewards, episode length and training error in one figure
-fig, axs = plt.subplots(1, 3, figsize=(20, 8))
 
-# np.convolve will compute the rolling mean for 100 episodes
-
-axs[0].plot(np.convolve(env.return_queue, np.ones(100)/100))
-axs[0].set_title("Episode Rewards")
-axs[0].set_xlabel("Episode")
-axs[0].set_ylabel("Reward")
+def get_moving_avgs(arr, window, convolution_mode):
+    return np.convolve(
+        np.array(arr).flatten(),
+        np.ones(window),
+        mode=convolution_mode
+    ) / window
+
+# Smooth over a 500 episode window
+rolling_length = 500
+fig, axs = plt.subplots(ncols=3, figsize=(12, 5))
+
+axs[0].set_title("Episode rewards")
+reward_moving_average = get_moving_avgs(
+    env.return_queue,
+    rolling_length,
+    "valid"
+)
+axs[0].plot(range(len(reward_moving_average)), reward_moving_average)
 
-axs[1].plot(np.convolve(env.length_queue, np.ones(100)/100))
-axs[1].set_title("Episode Lengths")
-axs[1].set_xlabel("Episode")
-axs[1].set_ylabel("Length")
+axs[1].set_title("Episode lengths")
+length_moving_average = get_moving_avgs(
+    env.length_queue,
+    rolling_length,
+    "valid"
+)
+axs[1].plot(range(len(length_moving_average)), length_moving_average)
 
-axs[2].plot(np.convolve(agent.training_error, np.ones(100)/100))
 axs[2].set_title("Training Error")
-axs[2].set_xlabel("Episode")
-axs[2].set_ylabel("Temporal Difference")
-
+training_error_moving_average = get_moving_avgs(
+    agent.training_error,
+    rolling_length,
+    "same"
+)
+axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
 plt.tight_layout()
 plt.show()
-```
-
-![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot")
 
-## Visualising the policy
 
-![](../_static/img/tutorials/blackjack_with_usable_ace.png "With a usable ace")
+```
 
-![](../_static/img/tutorials/blackjack_without_usable_ace.png "Without a usable ace")
+![](../_static/img/tutorials/blackjack_training_plots.png "Training Plot")
 
 Hopefully this tutorial helped you get a grip of how to interact with Gymnasium environments and sets you on a journey to solve many more RL challenges.
````
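The core of this change is the new `get_moving_avgs` helper, which computes a rolling mean by convolving the data with a window of ones. A minimal standalone sketch of its behavior (the helper is reproduced from the diff; the sample `rewards` list is made up for illustration):

```python
import numpy as np

def get_moving_avgs(arr, window, convolution_mode):
    # Rolling mean: convolve with a window of ones, then divide by the window size.
    return np.convolve(
        np.array(arr).flatten(),
        np.ones(window),
        mode=convolution_mode
    ) / window

rewards = [0, 1, 1, 0, 1, 1, 1, 0]

# "valid" keeps only positions where the window fully overlaps the data,
# so the output has len(arr) - window + 1 points (used for rewards/lengths).
valid = get_moving_avgs(rewards, 4, "valid")
print(len(valid))      # 5
print(valid[0])        # 0.5  (mean of the first four values)

# "same" zero-pads at the edges so output length matches input length,
# which is why the diff uses it for agent.training_error.
same = get_moving_avgs(rewards, 4, "same")
print(len(same))       # 8
```

Note the trade-off: `"valid"` avoids edge artifacts but shortens the series, while `"same"` preserves length at the cost of biased averages near the boundaries.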