README.md (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@ https://pytorch.org/examples/
- [Variational Auto-Encoders](./vae/README.md)
- [Superresolution using an efficient sub-pixel convolutional neural network](./super_resolution/README.md)
- [Hogwild training of shared ConvNets across multiple processes on MNIST](mnist_hogwild)
-- [Training a CartPole to balance in OpenAI Gym with actor-critic](./reinforcement_learning/README.md)
+- [Training a CartPole to balance with actor-critic](./reinforcement_learning/README.md)
- [Natural Language Inference (SNLI) with GloVe vectors, LSTMs, and torchtext](snli)
- [Time sequence prediction - use an LSTM to learn Sine waves](./time_sequence_prediction/README.md)
- [Implement the Neural Style Transfer algorithm on images](./fast_neural_style/README.md)
distributed/rpc/batch/reinforce.py (11 changes: 5 additions & 6 deletions)
@@ -1,5 +1,5 @@
import argparse
-import gym
+import gymnasium as gym
import os
import threading
import time
@@ -68,7 +68,7 @@ class Observer:
def __init__(self, batch=True):
self.id = rpc.get_worker_info().id - 1
self.env = gym.make('CartPole-v1')
-self.env.seed(args.seed)
+self.env.reset(seed=args.seed)
self.select_action = Agent.select_action_batch if batch else Agent.select_action

def run_episode(self, agent_rref, n_steps):
@@ -92,10 +92,10 @@ def run_episode(self, agent_rref, n_steps):
)

# apply the action to the environment, and get the reward
-state, reward, done, _ = self.env.step(action)
+state, reward, terminated, truncated, _ = self.env.step(action)
rewards[step] = reward

-if done or step + 1 >= n_steps:
+if terminated or truncated or step + 1 >= n_steps:
curr_rewards = rewards[start_step:(step + 1)]
R = 0
for i in range(curr_rewards.numel() -1, -1, -1):
Expand Down Expand Up @@ -226,8 +226,7 @@ def run_worker(rank, world_size, n_episode, batch, print_log=True):
last_reward, running_reward = agent.run_episode(n_steps=NUM_STEPS)

if print_log:
-print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
-i_episode, last_reward, running_reward))
+print(f'Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: {running_reward:.2f}')
else:
# other ranks are the observer
rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
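For reference, a minimal sketch of the Gymnasium calling convention the hunks above adopt; the environment id matches the example, while the seed value and the random policy are purely illustrative and not part of this PR:

import gymnasium as gym

env = gym.make('CartPole-v1')

# Seeding moved from env.seed(seed) to reset(seed=...), and reset() now returns
# an (observation, info) pair instead of just the observation.
state, info = env.reset(seed=543)

for _ in range(500):
    action = env.action_space.sample()  # a trained agent would choose the action here
    # step() returns five values: the old single `done` flag is split into
    # `terminated` (a terminal state was reached) and `truncated` (the time limit was hit)
    state, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        state, info = env.reset()

env.close()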
distributed/rpc/batch/requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
torch==2.2.0
torchvision==0.7.0
numpy
-gym
+gymnasium
docs/source/index.rst (6 changes: 3 additions & 3 deletions)
@@ -88,12 +88,12 @@ experiment with PyTorch.
`GO TO EXAMPLE <https://github.com/pytorch/examples/blob/main/mnist_hogwild>`__ :opticon:`link-external`

---
-Training a CartPole to balance in OpenAI Gym with actor-critic
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Training a CartPole to balance with actor-critic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This reinforcement learning tutorial demonstrates how to train a
CartPole to balance
-in the `OpenAI Gym <https://gym.openai.com/>`__ toolkit by using the
+in the `Gymnasium <https://gymnasium.farama.org/>`__ toolkit by using the
`Actor-Critic <https://proceedings.neurips.cc/paper/1999/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf>`__ method.

`GO TO EXAMPLE <https://github.com/pytorch/examples/blob/main/reinforcement_learning>`__ :opticon:`link-external`
reinforcement_learning/actor_critic.py (19 changes: 8 additions & 11 deletions)
@@ -1,5 +1,5 @@
import argparse
-import gym
+import gymnasium as gym
import numpy as np
from itertools import count
from collections import namedtuple
@@ -24,7 +24,8 @@
args = parser.parse_args()


-env = gym.make('CartPole-v1')
+render_mode = "human" if args.render else None
+env = gym.make('CartPole-v1', render_mode=render_mode)
env.reset(seed=args.seed)
torch.manual_seed(args.seed)

@@ -152,14 +153,11 @@ def main():
action = select_action(state)

# take the action
-state, reward, done, _, _ = env.step(action)
-
-if args.render:
-    env.render()
+state, reward, terminated, truncated, _ = env.step(action)

model.rewards.append(reward)
ep_reward += reward
-if done:
+if terminated or truncated:
break

# update cumulative reward
@@ -170,13 +168,12 @@

# log results
if i_episode % args.log_interval == 0:
-print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
-i_episode, ep_reward, running_reward))
+print(f'Episode {i_episode}\tLast reward: {ep_reward:.2f}\tAverage reward: {running_reward:.2f}')

# check if we have "solved" the cart pole problem
if running_reward > env.spec.reward_threshold:
print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format(running_reward, t))
print(f"Solved! Running reward is now {running_reward} and "
f"the last episode runs to {t} time steps!")
break


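A note on the rendering change in this file: Gymnasium selects the render backend when the environment is constructed, so the per-step env.render() call becomes unnecessary. A minimal sketch of the pattern, in which the flag only stands in for the example's --render argument and the seed is illustrative:

import gymnasium as gym

render = False  # stands in for args.render from the example's argparse
env = gym.make('CartPole-v1', render_mode="human" if render else None)

# With render_mode="human", a window is drawn automatically on every
# reset()/step() call, so no explicit env.render() is needed in the loop.
state, _ = env.reset(seed=543)
state, reward, terminated, truncated, _ = env.step(env.action_space.sample())
env.close()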
reinforcement_learning/reinforce.py (15 changes: 7 additions & 8 deletions)
@@ -1,5 +1,5 @@
import argparse
-import gym
+import gymnasium as gym
import numpy as np
from itertools import count
from collections import deque
@@ -22,7 +22,8 @@
args = parser.parse_args()


-env = gym.make('CartPole-v1')
+render_mode = "human" if args.render else None
+env = gym.make('CartPole-v1', render_mode=render_mode)
env.reset(seed=args.seed)
torch.manual_seed(args.seed)

@@ -85,22 +86,20 @@ def main():
ep_reward = 0
for t in range(1, 10000): # Don't infinite loop while learning
action = select_action(state)
-state, reward, done, _, _ = env.step(action)
+state, reward, terminated, truncated, _ = env.step(action)
if args.render:
env.render()
policy.rewards.append(reward)
ep_reward += reward
-if done:
+if terminated or truncated:
break

running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
finish_episode()
if i_episode % args.log_interval == 0:
-print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
-i_episode, ep_reward, running_reward))
+print(f'Episode {i_episode}\tLast reward: {ep_reward:.2f}\tAverage reward: {running_reward:.2f}')
if running_reward > env.spec.reward_threshold:
print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format(running_reward, t))
print(f"Solved! Running reward is now {running_reward} and the last episode runs to {t} time steps!")
break


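One behavioural note on the done -> terminated/truncated split used in both training loops: on CartPole-v1 an episode can end either because the pole fell (terminated) or because the default 500-step time limit was reached (truncated), and training must stop stepping in both cases. A small sketch, assuming only the TimeLimit wrapper that gym.make applies by default; the random policy is illustrative:

import gymnasium as gym

env = gym.make('CartPole-v1')  # gym.make wraps CartPole-v1 in a 500-step TimeLimit
state, _ = env.reset(seed=0)

while True:
    state, reward, terminated, truncated, _ = env.step(env.action_space.sample())
    # terminated: the pole fell over or the cart left the track (an MDP terminal state)
    # truncated: the episode was cut short by the time limit
    if terminated or truncated:
        break
env.close()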
reinforcement_learning/requirements.txt (5 changes: 2 additions & 3 deletions)
@@ -1,4 +1,3 @@
torch
numpy<2
gym
pygame
numpy
gymnasium[classic-control]