Commit 56c153f

Dev1nWaraffin and Antonin RAFFIN authored
Add warning when using PPO on GPU and update doc (#2017)
* Update documentation

  Added a note to the PPO documentation that the CPU should primarily be used unless a CNN is involved, along with sample code. Added a warning for both PPO and A2C that the CPU should be used when running on the GPU without a CNN; see Issue #1245.

* Add warning to base class and add test

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 512eea9 commit 56c153f
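
As a quick illustration of the behavior this commit introduces (a minimal sketch, not part of the diff; it assumes a CUDA-capable machine, since SB3 silently falls back to the CPU when CUDA is unavailable and the warning then never fires):

    import warnings

    from stable_baselines3 import PPO

    # MlpPolicy on the GPU now triggers the new UserWarning at construction time
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        PPO("MlpPolicy", "CartPole-v1", device="cuda")
    print(any("on the GPU" in str(w.message) for w in caught))  # True on a CUDA machine

    # The recommended setting: explicitly request the CPU (no warning)
    model = PPO("MlpPolicy", "CartPole-v1", device="cpu")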

5 files changed: +56 additions, −4 deletions


docs/misc/changelog.rst

Lines changed: 3 additions & 1 deletion
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.4.0a9 (WIP)
+Release 2.4.0a10 (WIP)
 --------------------------
 
 .. note::
@@ -60,12 +60,14 @@ Others:
 - Fixed various typos (@cschindlbeck)
 - Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
 - Updated PyTorch version on CI to 2.3.1
+- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
 
 Bug Fixes:
 ^^^^^^^^^^
 
 Documentation:
 ^^^^^^^^^^^^^^
+- Updated PPO doc to recommend using CPU with ``MlpPolicy``
 
 Release 2.3.2 (2024-04-27)
 --------------------------

docs/modules/ppo.rst

Lines changed: 17 additions & 0 deletions
@@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
     vec_env.render("human")
 
 
+.. note::
+
+  PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:
+
+  .. code-block::
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.env_util import make_vec_env
+    from stable_baselines3.common.vec_env import SubprocVecEnv
+
+    if __name__ == "__main__":
+        env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
+        model = PPO("MlpPolicy", env, device="cpu")
+        model.learn(total_timesteps=25_000)
+
+  For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.
+
 Results
 -------
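
To see what the new note is getting at, one can time the two vectorized-environment classes side by side (a rough benchmark sketch, not from the commit; whether ``SubprocVecEnv`` pays off depends on how expensive a single environment step is, and for a trivial env like ``CartPole-v1`` the process overhead may dominate):

    import time

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

    if __name__ == "__main__":
        # DummyVecEnv steps all envs in the main process; SubprocVecEnv uses one process per env
        for vec_env_cls in (DummyVecEnv, SubprocVecEnv):
            env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=vec_env_cls)
            model = PPO("MlpPolicy", env, device="cpu")
            start = time.perf_counter()
            model.learn(total_timesteps=25_000)
            print(f"{vec_env_cls.__name__}: {time.perf_counter() - start:.1f}s")
            env.close()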

stable_baselines3/common/on_policy_algorithm.py

Lines changed: 23 additions & 0 deletions
@@ -1,5 +1,6 @@
 import sys
 import time
+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
@@ -135,6 +136,28 @@ def _setup_model(self) -> None:
             self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
         )
         self.policy = self.policy.to(self.device)
+        # Warn when not using CPU with MlpPolicy
+        self._maybe_recommend_cpu()
+
+    def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
+        """
+        Recommend to use CPU only when using A2C/PPO with MlpPolicy.
+
+        :param mlp_class_name: The name of the class for the default MlpPolicy.
+        """
+        policy_class_name = self.policy_class.__name__
+        if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
+            warnings.warn(
+                f"You are trying to run {self.__class__.__name__} on the GPU, "
+                "but it is primarily intended to run on the CPU when not using a CNN policy "
+                f"(you are using {policy_class_name} which should be a MlpPolicy). "
+                "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
+                "for more info. "
+                "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU. "
+                "Note: The model will train, but the GPU utilization will be poor and "
+                "the training might take longer than on CPU.",
+                UserWarning,
+            )
 
     def collect_rollouts(
         self,
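
For users who run on the GPU deliberately, the warning points at two escape hatches; a minimal sketch of both (assuming a CUDA device is available):

    import warnings

    from stable_baselines3 import A2C

    # Option 1 (what the warning recommends): force the CPU
    model = A2C("MlpPolicy", "CartPole-v1", device="cpu")

    # Option 2: keep the GPU and silence this specific UserWarning
    warnings.filterwarnings("ignore", message=".*primarily intended to run on the CPU.*")
    model = A2C("MlpPolicy", "CartPole-v1", device="cuda")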

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.4.0a9
+2.4.0a10

tests/test_run.py

Lines changed: 12 additions & 2 deletions
@@ -1,6 +1,7 @@
 import gymnasium as gym
 import numpy as np
 import pytest
+import torch as th
 
 from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
 from stable_baselines3.common.env_util import make_vec_env
@@ -211,8 +212,11 @@ def test_warn_dqn_multi_env():
 
 
 def test_ppo_warnings():
-    """Test that PPO warns and errors correctly on
-    problematic rollout buffer sizes"""
+    """
+    Test that PPO warns and errors correctly on
+    problematic rollout buffer sizes,
+    and recommends using CPU.
+    """
 
     # Only 1 step: advantage normalization will return NaN
     with pytest.raises(AssertionError):
@@ -234,3 +238,9 @@ def test_ppo_warnings():
     loss = model.logger.name_to_value["train/loss"]
     assert loss > 0
     assert not np.isnan(loss)  # check not nan (since nan does not equal nan)
+
+    with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"):
+        model = PPO("MlpPolicy", "Pendulum-v1")
+        # Pretend to be on the GPU
+        model.device = th.device("cuda")
+        model._maybe_recommend_cpu()
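
Note that the test fakes the GPU by overwriting ``model.device`` directly (the 'Pretend to be on the GPU' comment above), so it also passes on CPU-only CI runners; it can be run in isolation with ``pytest tests/test_run.py -k test_ppo_warnings``.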
