36 changes: 36 additions & 0 deletions examples/ppo_breakout_cnn.py
@@ -0,0 +1,36 @@
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
import gymnasium as gym  # noqa: F401
import ale_py  # noqa: F401  # ALE backend; importing it registers the ALE/* environments with Gymnasium
import shimmy  # noqa: F401  # only needed with older ale-py releases that expose ALE to Gymnasium via Shimmy

# 1. Create a vectorized Atari environment and stack frames
env = make_atari_env("ALE/Breakout-v5", n_envs=4, seed=0)
env = VecFrameStack(env, n_stack=4)

# 2. Instantiate PPO with a CNN policy
model = PPO(
    policy="CnnPolicy",
    env=env,
    verbose=1,
)
# 3. Train the agent
model.learn(total_timesteps=200_000)

# 4. Save the trained model
model.save("ppo_breakout_cnn")

# 5. Evaluate the trained agent
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} ± {std_reward:.2f}")

# 6. (Optional) Run a few episodes and render
# (on-screen rendering requires creating the env with render_mode="human",
#  e.g. make_atari_env(..., env_kwargs={"render_mode": "human"}))
obs = env.reset()
for _ in range(10_000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    env.render()
    # No manual reset needed: SB3 VecEnvs reset finished sub-environments automatically.
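The torch_layers diff below adds a use_batch_norm flag to NatureCNN. A minimal sketch of how that flag could be switched on from this example, assuming SB3's standard features_extractor_kwargs plumbing (part of the existing policy API, not this PR) forwards the keyword to the extractor:

# Hypothetical variant of step 2: turn on the BatchNorm option added to NatureCNN in this PR.
# CnnPolicy passes features_extractor_kwargs straight to the NatureCNN constructor.
model_bn = PPO(
    policy="CnnPolicy",
    env=env,
    policy_kwargs=dict(features_extractor_kwargs=dict(use_batch_norm=True)),
    verbose=1,
)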
136 changes: 52 additions & 84 deletions stable_baselines3/common/torch_layers.py
@@ -67,41 +67,44 @@ def __init__(
observation_space: gym.Space,
features_dim: int = 512,
normalized_image: bool = False,
use_batch_norm: bool = False,
) -> None:
assert isinstance(observation_space, spaces.Box), (
"NatureCNN must be used with a gym.spaces.Box ",
f"observation space, not {observation_space}",
)
assert isinstance(observation_space, spaces.Box), "NatureCNN only supports Box"
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
"You should use NatureCNN "
f"only with images not with {observation_space}\n"
"(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
"If you are using a custom environment,\n"
"please check it using our env checker:\n"
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n"
"If you are using `VecNormalize` or already normalized channel-first images "
"you should pass `normalize_images=False`: \n"
"https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
)
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(
nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
)

# Compute shape by doing one forward pass
assert is_image_space(
observation_space, check_channels=False, normalized_image=normalized_image
), "NatureCNN only supports image spaces"

n_in = observation_space.shape[0]
layers: list[nn.Module] = []

# conv1
layers.append(nn.Conv2d(n_in, 32, kernel_size=8, stride=4))
if use_batch_norm:
layers.append(nn.BatchNorm2d(32))
layers.append(nn.ReLU())

# conv2
layers.append(nn.Conv2d(32, 64, kernel_size=4, stride=2))
if use_batch_norm:
layers.append(nn.BatchNorm2d(64))
layers.append(nn.ReLU())

# conv3
layers.append(nn.Conv2d(64, 64, kernel_size=3, stride=1))
if use_batch_norm:
layers.append(nn.BatchNorm2d(64))
layers.append(nn.ReLU())

layers.append(nn.Flatten())
self.cnn = nn.Sequential(*layers)

with th.no_grad():
n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
sample = observation_space.sample()[None]
n_flat = self.cnn(th.as_tensor(sample).float()).shape[1]

self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
linear_layers = [nn.Linear(n_flat, features_dim), nn.ReLU()]
self.linear = nn.Sequential(*linear_layers)

def forward(self, observations: th.Tensor) -> th.Tensor:
return self.linear(self.cnn(observations))
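As a quick smoke test for the new flag, a sketch that exercises the modified extractor directly (the observation space mirrors four stacked 84x84 frames; the shapes and batch size here are illustrative, not part of the PR):

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.torch_layers import NatureCNN

# Channels-first uint8 image space, as NatureCNN expects after SB3's preprocessing
obs_space = spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype=np.uint8)
extractor = NatureCNN(obs_space, features_dim=512, use_batch_norm=True)
# Batch of 8 random observations, cast to float as the policy preprocessing would do
batch = th.as_tensor(np.stack([obs_space.sample() for _ in range(8)])).float()
features = extractor(batch)
assert features.shape == (8, 512)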
@@ -116,70 +119,35 @@ def create_mlp(
with_bias: bool = True,
pre_linear_modules: Optional[list[type[nn.Module]]] = None,
post_linear_modules: Optional[list[type[nn.Module]]] = None,
use_batch_norm: bool = False,
) -> list[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
a collection of fully-connected layers each followed by an activation function.

:param input_dim: Dimension of the input vector
:param output_dim: Dimension of the output (last layer, for instance, the number of actions)
:param net_arch: Architecture of the neural net
It represents the number of units per layer.
The length of this list is the number of layers.
:param activation_fn: The activation function
to use after each layer.
:param squash_output: Whether to squash the output using a Tanh
activation function
:param with_bias: If set to False, the layers will not learn an additive bias
:param pre_linear_modules: List of nn.Module to add before the linear layers.
These modules should maintain the input tensor dimension (e.g. BatchNorm).
The number of input features is passed to the module's constructor.
Compared to post_linear_modules, they are used before the output layer (output_dim > 0).
:param post_linear_modules: List of nn.Module to add after the linear layers
(and before the activation function). These modules should maintain the input
tensor dimension (e.g. Dropout, LayerNorm). They are not used after the
output layer (output_dim > 0). The number of input features is passed to
the module's constructor.
    :param use_batch_norm: If True, add a BatchNorm1d layer that normalizes the input
        before the first linear layer
    :return: The list of layers of the neural network
"""
modules: list[nn.Module] = []

if use_batch_norm:
modules.append(nn.BatchNorm1d(input_dim))

pre_linear_modules = pre_linear_modules or []
post_linear_modules = post_linear_modules or []

modules = []
if len(net_arch) > 0:
# BatchNorm maintains input dim
for module in pre_linear_modules:
modules.append(module(input_dim))

modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias))

# LayerNorm, Dropout maintain output dim
for module in post_linear_modules:
modules.append(module(net_arch[0]))

modules.append(activation_fn())

for idx in range(len(net_arch) - 1):
for module in pre_linear_modules:
modules.append(module(net_arch[idx]))

modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))

for module in post_linear_modules:
modules.append(module(net_arch[idx + 1]))

last_dim = input_dim
for layer_size in net_arch:
for mod in pre_linear_modules:
modules.append(mod(last_dim))
modules.append(nn.Linear(last_dim, layer_size, bias=with_bias))
for mod in post_linear_modules:
modules.append(mod(layer_size))
modules.append(activation_fn())
last_dim = layer_size

if output_dim > 0:
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
# Only add BatchNorm before output layer
for module in pre_linear_modules:
modules.append(module(last_layer_dim))
for mod in pre_linear_modules:
modules.append(mod(last_dim))
modules.append(nn.Linear(last_dim, output_dim, bias=with_bias))
last_dim = output_dim

modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
if squash_output:
modules.append(nn.Tanh())

return modules
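For reference, a small usage sketch of the rewritten helper (values are illustrative): create_mlp still returns a plain list of modules that the caller wraps in nn.Sequential, and post_linear_modules are applied after each hidden linear layer but not after the output layer.

import torch.nn as nn
from stable_baselines3.common.torch_layers import create_mlp

# 8-dim input -> two hidden layers of 64 units -> 2 outputs,
# with BatchNorm1d on the raw input and LayerNorm after each hidden Linear
layers = create_mlp(
    input_dim=8,
    output_dim=2,
    net_arch=[64, 64],
    activation_fn=nn.ReLU,
    use_batch_norm=True,
    post_linear_modules=[nn.LayerNorm],
)
mlp = nn.Sequential(*layers)
print(mlp)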

