Commit 3d59b5c

Use uv on GitHub CI for faster download and update changelog (#2026)

* Use uv on GitHub CI for faster download and update changelog
* Fix new mypy issues

1 parent: 56c153f

File tree

5 files changed: +18, -7 lines


.github/workflows/ci.yml (7 additions & 4 deletions)

@@ -31,18 +31,21 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          # Use uv for faster downloads
+          pip install uv
           # cpu version of pytorch
-          pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
+          # See https://github.com/astral-sh/uv/issues/1497
+          uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu

           # Install Atari Roms
-          pip install autorom
+          uv pip install --system autorom
           wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
           base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
           AutoROM --accept-license --source-file Roms.tar.gz

-          pip install .[extra_no_roms,tests,docs]
+          uv pip install --system .[extra_no_roms,tests,docs]
           # Use headless version
-          pip install opencv-python-headless
+          uv pip install --system opencv-python-headless
       - name: Lint with ruff
         run: |
           make lint

docs/guide/sb3_contrib.rst (1 addition & 0 deletions)

@@ -42,6 +42,7 @@ See documentation for the full list of included features.
 - `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
 - `Truncated Quantile Critics (TQC)`_
 - `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
+- `Batch Normalization in Deep Reinforcement Learning (CrossQ) <https://openreview.net/forum?id=PczQtTsTIX>`_


 **Gym Wrappers**:

docs/misc/changelog.rst (7 additions & 0 deletions)

@@ -6,6 +6,8 @@ Changelog
 Release 2.4.0a10 (WIP)
 --------------------------

+**New algorithm: CrossQ in SB3 Contrib**
+
 .. note::

   DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
@@ -43,6 +45,10 @@ Bug Fixes:

 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
+- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
+- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
+- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
+- Fixed loading QRDQN changes ``target_update_interval`` (@jak3122)

 `RL Zoo`_
 ^^^^^^^^^
@@ -61,6 +67,7 @@ Others:
 - Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
 - Updated PyTorch version on CI to 2.3.1
 - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
+- Switched to uv to download packages faster on GitHub CI

 Bug Fixes:
 ^^^^^^^^^^
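The changelog above mentions the new ``BatchRenorm`` layer that ``CrossQ`` relies on. As a rough illustration of the batch renormalization idea (Ioffe, 2017) — not SB3-Contrib's actual PyTorch implementation — a single training-mode forward pass can be sketched in NumPy. The names and the clipping bounds `r_max`, `d_max`, and `eps` here are illustrative defaults, not values taken from the library:

```python
import numpy as np

def batch_renorm(x, running_mean, running_var, r_max=3.0, d_max=5.0, eps=1e-5):
    """One training-mode pass of batch renormalization over axis 0.

    Like batch norm, but the batch statistics are corrected toward the
    running statistics via r and d, which narrows the train/eval gap.
    """
    batch_mean = x.mean(axis=0)
    batch_std = np.sqrt(x.var(axis=0) + eps)
    running_std = np.sqrt(running_var + eps)
    # In the real layer, r and d are treated as constants (no gradient flows
    # through them) and are clipped to keep early training stable.
    r = np.clip(batch_std / running_std, 1.0 / r_max, r_max)
    d = np.clip((batch_mean - running_mean) / running_std, -d_max, d_max)
    return (x - batch_mean) / batch_std * r + d
```

When the running statistics match the batch statistics, `r` is 1 and `d` is 0, and the layer reduces to plain batch normalization.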

stable_baselines3/common/utils.py (2 additions & 2 deletions)

@@ -46,7 +46,7 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None:


 # From stable baselines
-def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
+def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
     """
     Computes fraction of variance that ypred explains about y.
     Returns 1 - Var[y-ypred] / Var[y]
@@ -62,7 +62,7 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
     """
     assert y_true.ndim == 1 and y_pred.ndim == 1
     var_y = np.var(y_true)
-    return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
+    return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y)


 def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
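The ``float(...)`` cast in the diff above is what satisfies the new return annotation: ``np.var`` returns a NumPy scalar, so without the cast the declared ``-> float`` does not hold under mypy. A standalone sketch of the patched function, for a quick sanity check:

```python
import numpy as np

def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Fraction of variance that y_pred explains about y_true:
    1 - Var[y_true - y_pred] / Var[y_true]; NaN when Var[y_true] is 0."""
    assert y_true.ndim == 1 and y_pred.ndim == 1
    var_y = np.var(y_true)
    # float(...) converts the NumPy scalar to a plain Python float
    return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y)
```

A perfect prediction yields 1.0, and a constant target (zero variance) yields NaN rather than a division-by-zero error.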

tests/test_utils.py (1 addition & 1 deletion)

@@ -177,7 +177,7 @@ def test_custom_vec_env(tmp_path):


 @pytest.mark.parametrize("direct_policy", [False, True])
-def test_evaluate_policy(direct_policy: bool):
+def test_evaluate_policy(direct_policy):
     model = A2C("MlpPolicy", "Pendulum-v1", seed=0)
     n_steps_per_episode, n_eval_episodes = 200, 2
