Skip to content

Commit 1bee58f

Browse files
Fix cuda not being used on windows: update pytorch version (#6064)
These references were missed when upgrading from pytorch 1.x to 2.x in #6013 References found by running `grep -R '1\.13\.1' .` Install command chosen from the guide at https://pytorch.org/get-started/locally/
1 parent 6004556 commit 1bee58f

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

.yamato/pytest-gpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pytest_gpu:
1111
python3 -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
1212
python3 -u -m ml-agents.tests.yamato.setup_venv
1313
python3 -m pip install --progress-bar=off -r test_requirements.txt --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
14-
python3 -m pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
14+
python3 -m pip install torch==2.2.1+cu121 torchvision==0.17.1+cu121 torchaudio==0.17.1 --index-url https://download.pytorch.org/whl/cu121
1515
if python -c "exec('import torch \nif not torch.cuda.is_available(): raise')" &> /dev/null; then
1616
echo 'all good'
1717
else

docs/Installation.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,12 @@ offer a dedicated [guide on Virtual Environments](Using-Virtual-Environment.md).
146146
#### (Windows) Installing PyTorch
147147

148148
On Windows, you'll have to install the PyTorch package separately prior to
149-
installing ML-Agents. Activate your virtual environment and run from the command line:
149+
installing ML-Agents in order to make sure the cuda-enabled version is used,
150+
rather than the CPU-only version. Activate your virtual environment and run from
151+
the command line:
150152

151153
```sh
152-
pip3 install torch~=1.13.1 -f https://download.pytorch.org/whl/torch_stable.html
154+
pip3 install torch~=2.2.1 --index-url https://download.pytorch.org/whl/cu121
153155
```
154156

155157
Note that on Windows, you may also need Microsoft's

0 commit comments

Comments
 (0)