Description
Describe the bug
I have observed a considerable decrease in policy performance after the recent PyTorch 2.5.0 update. The decrease reproduces when training with A2C, REINFORCE, and PPO.
Training curves (plot): brown is before the update, purple is after. Both runs use the same environment model and the same random seeds.

To Reproduce
Install RL4CO and other dependencies using the following Conda environment.yaml:
name: rl
channels:
  - conda-forge
  - defaults
dependencies:
  - pip
  - python=3.12.7
  - pip:
    - rl4co
    # data analysis
    - polars
    - pandas
    # data visualization
    - matplotlib
    - seaborn
    # logging
    - tensorboard
Previous result when creating the environment
Approximately 3 days ago this would've installed the following dependencies:
INSTALLED VERSIONS
-------------------------------------
rl4co : 0.5.0
torch : 2.4.1+cu121
lightning : 2.4.0
torchrl : 0.5.0
tensordict : 0.5.0
numpy : 2.1.2
pytorch_geometric : Not installed
hydra-core : 1.3.2
omegaconf : 2.3.0
matplotlib : 3.9.2
Python : 3.12.7
Platform : Linux-5.15.0-78-generic-x86_64-with-glibc2.35
Lightning device : cuda
This environment can be replicated with the following environment.yaml:
name: rl
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.12.7
  - pip
  - pip:
    - -r requirements.txt
where requirements.txt must be stored in the same directory as environment.yaml and contain:
setuptools==75.1.0
wheel==0.44.0
pip==24.2
pytz==2024.2
mpmath==1.3.0
antlr4-python3-runtime==4.9.3
urllib3==2.2.3
tzdata==2024.2
typing_extensions==4.12.2
tqdm==4.66.5
tensorboard-data-server==0.7.2
sympy==1.13.3
smmap==5.0.1
six==1.16.0
setproctitle==1.3.3
PyYAML==6.0.2
python-dotenv==1.0.1
pyparsing==3.2.0
Pygments==2.18.0
psutil==6.0.0
protobuf==5.28.2
propcache==0.2.0
polars==1.9.0
platformdirs==4.3.6
pillow==11.0.0
packaging==24.1
orjson==3.10.7
nvidia-nvtx-cu12==12.1.105
nvidia-nvjitlink-cu12==12.6.77
nvidia-nccl-cu12==2.20.5
nvidia-curand-cu12==10.3.2.106
nvidia-cufft-cu12==11.0.2.54
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cublas-cu12==12.1.3.1
numpy==2.1.2
networkx==3.4.1
multidict==6.1.0
mdurl==0.1.2
MarkupSafe==3.0.1
Markdown==3.7
kiwisolver==1.4.7
idna==3.10
grpcio==1.67.0
fsspec==2024.9.0
frozenlist==1.4.1
fonttools==4.54.1
filelock==3.16.1
einops==0.8.0
cycler==0.12.1
colorlog==6.8.2
cloudpickle==3.1.0
click==8.1.7
charset-normalizer==3.4.0
certifi==2024.8.30
attrs==24.2.0
aiohappyeyeballs==2.4.3
absl-py==2.1.0
yarl==1.15.4
Werkzeug==3.0.4
triton==3.0.0
sentry-sdk==2.17.0
scipy==1.14.1
requests==2.32.3
python-dateutil==2.9.0.post0
pyrootutils==1.0.4
omegaconf==2.3.0
nvidia-cusparse-cu12==12.1.0.106
nvidia-cudnn-cu12==9.1.0.70
markdown-it-py==3.0.0
lightning-utilities==0.11.8
Jinja2==3.1.4
gitdb==4.0.11
docker-pycreds==0.4.0
contourpy==1.3.0
aiosignal==1.3.1
tensorboard==2.18.0
robust-downloader==0.0.2
rich==13.9.2
pandas==2.2.3
nvidia-cusolver-cu12==11.4.5.107
matplotlib==3.9.2
hydra-core==1.3.2
GitPython==3.1.43
aiohttp==3.10.10
wandb==0.18.3
torch==2.4.1
seaborn==0.13.2
hydra-colorlog==1.2.0
torchmetrics==1.4.3
tensordict==0.5.0
torchrl==0.5.0
pytorch-lightning==2.4.0
lightning==2.4.0
rl4co==0.5.0
pyDOE3==1.0.4
statsmodels==0.14.4
Current result when creating the environment
As of today it installs the following dependencies, including PyTorch 2.5.0:
INSTALLED VERSIONS
-------------------------------------
rl4co : 0.5.0
torch : 2.5.0+cu124
lightning : 2.4.0
torchrl : 0.5.0
tensordict : 0.5.0
numpy : 1.26.4
pytorch_geometric : Not installed
hydra-core : 1.3.2
omegaconf : 2.3.0
matplotlib : 3.9.2
Python : 3.12.7
Platform : Linux-6.5.0-35-generic-x86_64-with-glibc2.35
Lightning device : cuda
Detailed list of dependencies
The following table compares every dependency in the environment created 3 days ago with the current one. Most versions are identical; the rows that differ include torch, numpy, triton, torchmetrics, and the NVIDIA CUDA libraries. I believe PyTorch 2.5.0 is the main culprit here.
| Library | Version 3 days ago | Current version |
|---|---|---|
| PyYAML | 6.0.2 | 6.0.2 |
| Pygments | 2.18.0 | (missing) |
| absl-py | 2.1.0 | 2.1.0 |
| aiohappyeyeballs | 2.4.3 | 2.4.3 |
| aiohttp | 3.10.10 | 3.10.10 |
| aiosignal | 1.3.1 | 1.3.1 |
| antlr4-python3-runtime | 4.9.3 | 4.9.3 |
| attrs | 24.2.0 | 24.2.0 |
| certifi | 2024.8.30 | 2024.8.30 |
| charset-normalizer | 3.4.0 | 3.4.0 |
| click | 8.1.7 | 8.1.7 |
| cloudpickle | 3.1.0 | 3.1.0 |
| colorlog | 6.8.2 | 6.8.2 |
| contourpy | 1.3.0 | 1.3.0 |
| cycler | 0.12.1 | 0.12.1 |
| docker-pycreds | 0.4.0 | 0.4.0 |
| einops | 0.8.0 | 0.8.0 |
| filelock | 3.16.1 | 3.16.1 |
| fonttools | 4.54.1 | 4.54.1 |
| frozenlist | 1.4.1 | 1.4.1 |
| fsspec | 2024.9.0 | 2024.10.0 |
| gitdb | 4.0.11 | 4.0.11 |
| GitPython | 3.1.43 | 3.1.43 |
| grpcio | 1.67.0 | 1.67.0 |
| hydra-colorlog | 1.2.0 | 1.2.0 |
| hydra-core | 1.3.2 | 1.3.2 |
| idna | 3.10 | 3.10 |
| Jinja2 | 3.1.4 | 3.1.4 |
| kiwisolver | 1.4.7 | 1.4.7 |
| lightning | 2.4.0 | 2.4.0 |
| lightning-utilities | 0.11.8 | 0.11.8 |
| Markdown | 3.7 | 3.7 |
| markdown-it-py | 3.0.0 | 3.0.0 |
| MarkupSafe | 3.0.1 | 3.0.2 |
| matplotlib | 3.9.2 | 3.9.2 |
| mdurl | 0.1.2 | 0.1.2 |
| mpmath | 1.3.0 | 1.3.0 |
| multidict | 6.1.0 | 6.1.0 |
| networkx | 3.4.1 | 3.4.1 |
| numpy | 2.1.2 | 1.26.4 |
| nvidia-cublas-cu12 | 12.1.3.1 | 12.4.5.8 |
| nvidia-cuda-cupti-cu12 | 12.1.105 | 12.4.127 |
| nvidia-cuda-nvrtc-cu12 | 12.1.105 | 12.4.127 |
| nvidia-cuda-runtime-cu12 | 12.1.105 | 12.4.127 |
| nvidia-cudnn-cu12 | 9.1.0.70 | 9.1.0.70 |
| nvidia-cufft-cu12 | 11.0.2.54 | 11.2.1.3 |
| nvidia-curand-cu12 | 10.3.2.106 | 10.3.5.147 |
| nvidia-cusolver-cu12 | 11.4.5.107 | 11.6.1.9 |
| nvidia-cusparse-cu12 | 12.1.0.106 | 12.3.1.170 |
| nvidia-nccl-cu12 | 2.20.5 | 2.21.5 |
| nvidia-nvjitlink-cu12 | 12.6.77 | 12.4.127 |
| nvidia-nvtx-cu12 | 12.1.105 | 12.4.127 |
| omegaconf | 2.3.0 | 2.3.0 |
| orjson | 3.10.7 | 3.10.9 |
| packaging | 24.1 | 24.1 |
| pandas | 2.2.3 | 2.2.3 |
| patsy | (missing) | 0.5.6 |
| pillow | 11.0.0 | 11.0.0 |
| platformdirs | 4.3.6 | 4.3.6 |
| polars | 1.9.0 | 1.10.0 |
| propcache | 0.2.0 | 0.2.0 |
| protobuf | 5.28.2 | 5.28.2 |
| psutil | 6.0.0 | 6.1.0 |
| pyDOE3 | 1.0.4 | 1.0.4 |
| pyparsing | 3.2.0 | 3.2.0 |
| pyrootutils | 1.0.4 | 1.0.4 |
| python-dateutil | 2.9.0.post0 | 2.9.0.post0 |
| python-dotenv | 1.0.1 | 1.0.1 |
| pytorch-lightning | 2.4.0 | 2.4.0 |
| pytz | 2024.2 | 2024.2 |
| requests | 2.32.3 | 2.32.3 |
| rich | 13.9.2 | 13.9.2 |
| rl4co | 0.5.0 | 0.5.0 |
| robust-downloader | 0.0.2 | 0.0.2 |
| scipy | 1.14.1 | 1.14.1 |
| seaborn | 0.13.2 | 0.13.2 |
| sentry-sdk | 2.17.0 | 2.17.0 |
| setproctitle | 1.3.3 | 1.3.3 |
| setuptools | 75.1.0 | 75.1.0 |
| six | 1.16.0 | 1.16.0 |
| smmap | 5.0.1 | 5.0.1 |
| statsmodels | 0.14.4 | 0.14.4 |
| sympy | 1.13.3 | 1.13.1 |
| tensorboard | 2.18.0 | 2.18.0 |
| tensorboard-data-server | 0.7.2 | 0.7.2 |
| tensordict | 0.5.0 | 0.5.0 |
| torch | 2.4.1 | 2.5.0 |
| torchmetrics | 1.4.3 | 1.5.0 |
| torchrl | 0.5.0 | 0.5.0 |
| tqdm | 4.66.5 | 4.66.5 |
| triton | 3.0.0 | 3.1.0 |
| typing_extensions | 4.12.2 | 4.12.2 |
| tzdata | 2024.2 | 2024.2 |
| urllib3 | 2.2.3 | 2.2.3 |
| wandb | 0.18.3 | 0.18.5 |
| Werkzeug | 3.0.4 | 3.0.4 |
| wheel | 0.44.0 | 0.44.0 |
| yarl | 1.15.4 | 1.15.5 |
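To narrow the table down to only the packages that changed, a small script can compare two `pip freeze` outputs. This is a minimal sketch, not part of RL4CO: the function names `parse_freeze` and `diff_freeze` are hypothetical, and it assumes both inputs use plain `name==version` lines like the requirements.txt above.

```python
def parse_freeze(text: str) -> dict[str, str]:
    """Parse `name==version` lines into a dict keyed by lowercase package name."""
    versions: dict[str, str] = {}
    for line in text.splitlines():
        line = line.strip()
        # Skip blanks, comments, and anything that is not a strict pin.
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, version = line.split("==", 1)
        versions[name.strip().lower()] = version.strip()
    return versions


def diff_freeze(old_text: str, new_text: str) -> list[tuple[str, str, str]]:
    """Return (name, old_version, new_version) for every package that differs."""
    old, new = parse_freeze(old_text), parse_freeze(new_text)
    rows = []
    for name in sorted(old.keys() | new.keys()):
        before = old.get(name, "(missing)")
        after = new.get(name, "(missing)")
        if before != after:
            rows.append((name, before, after))
    return rows


if __name__ == "__main__":
    # A few rows taken from the table above, used as sample input.
    old_sample = "torch==2.4.1\nnumpy==2.1.2\ntriton==3.0.0\nrl4co==0.5.0\n"
    new_sample = "torch==2.5.0\nnumpy==1.26.4\ntriton==3.1.0\nrl4co==0.5.0\n"
    for name, before, after in diff_freeze(old_sample, new_sample):
        print(f"{name}: {before} -> {after}")
```

In practice the two inputs would be the saved `pip freeze` outputs of the old and new environments. Unchanged packages such as rl4co are omitted from the result.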
System info
NVIDIA L40 system pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22
Reason and Possible fixes
I do not know the cause. A temporary workaround could be to pin the PyTorch version required by RL4CO to 2.4.1.
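Until the cause is found, the pin can be applied locally by adding `torch==2.4.1` to the pip requirements, as in the requirements.txt above. The sketch below then checks at startup that the pinned version is actually the one installed. The helper names are hypothetical, and the sketch assumes the installed version string may carry a local build suffix such as `+cu121`, as in the version listings above.

```python
from importlib import metadata


def base_version(version: str) -> str:
    """Drop a local build suffix, e.g. '2.4.1+cu121' -> '2.4.1'."""
    return version.split("+", 1)[0]


def matches_pin(installed: str, pinned: str) -> bool:
    """True when the installed version equals the pin, ignoring the build suffix."""
    return base_version(installed) == pinned


def check_installed(package: str = "torch", pinned: str = "2.4.1") -> bool:
    """Report whether the installed package matches the pinned version."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        print(f"{package} is not installed")
        return False
    ok = matches_pin(installed, pinned)
    status = "matches" if ok else "does NOT match"
    print(f"{package} {installed} {status} pin {pinned}")
    return ok


if __name__ == "__main__":
    check_installed()
```

Running this before training makes it harder to compare runs by accident across a silently upgraded PyTorch.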
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have provided a minimal working example to reproduce the bug (required)