Skip to content

Commit 222b64a

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
CI test for pytorch 1.8
Summary: Pull Request resolved: #2641 Reviewed By: alexander-kirillov Differential Revision: D26563791 Pulled By: ppwwyyxx fbshipit-source-id: 694f2b8a94ad35b1fc5b1328162828a88fed4b95
1 parent c90ff5e commit 222b64a

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

.circleci/config.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,14 @@ install_linux_dep: &install_linux_dep
9999
command: |
100100
pip install --progress-bar off -U 'git+https://github.com/facebookresearch/fvcore'
101101
pip install --progress-bar off ninja opencv-python-headless pytest-xdist tensorboard pycocotools
102-
pip install --progress-bar off torch==$PYTORCH_VERSION torchvision==$TORCHVISION_VERSION -f https://download.pytorch.org/whl/torch_stable.html
102+
# install from pytorch's test wheels index to have access to RC wheels
103+
pip install --progress-bar off torch==$PYTORCH_VERSION -f https://download.pytorch.org/whl/test/torch_test.html
104+
if [[ "$TORCHVISION_VERSION" == "master" ]]; then
105+
pip install git+https://github.com/pytorch/vision.git
106+
else
107+
pip install --progress-bar off torchvision==$TORCHVISION_VERSION -f https://download.pytorch.org/whl/test/torch_test.html
108+
fi
109+
103110
python -c 'import torch; print("CUDA:", torch.cuda.is_available())'
104111
gcc --version
105112
@@ -222,4 +229,8 @@ workflows:
222229
name: linux_gpu_tests_pytorch1.7
223230
pytorch_version: '1.7+cu101'
224231
torchvision_version: '0.8.1+cu101'
232+
- linux_gpu_tests:
233+
name: linux_gpu_tests_pytorch1.8
234+
pytorch_version: '1.8+cu101'
235+
torchvision_version: 'master'
225236
- windows_cpu_build

tests/test_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_complex_model_loaded(self):
4242
# different tensor references
4343
self.assertFalse(id(loaded) == id(stored))
4444
# same content
45-
self.assertTrue(loaded.equal(stored))
45+
self.assertTrue(loaded.to(stored).equal(stored))
4646

4747

4848
if __name__ == "__main__":

0 commit comments

Comments
 (0)