Commit 497bb90

update test_torch.py and align with common_utils.py

1 parent 8ce3682

1 file changed: +6 -34 lines

tests/cpu/test_torch.py
Lines changed: 6 additions & 34 deletions
@@ -75,8 +75,8 @@
     _compare_trilu_indices
 from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
     TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
-    IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
-    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
+    IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
+    IS_SANDCASTLE, load_tests, slowTest, \
     skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf
 from multiprocessing.reduction import ForkingPickler
 from common_device_type import instantiate_device_type_tests, \
@@ -1597,9 +1597,6 @@ def _test_multinomial_invalid_probs(probs):
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                      don't support multiprocessing with spawn start method")
     @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows')
-    @unittest.skipIf(not PY3,
-                     "spawn start method is not supported in Python 2, \
-                     but we need it for for testing failure case for CPU RNG on Windows")
     def test_multinomial_invalid_probs(self):
         test_method = _TestTorchMixin._test_multinomial_invalid_probs
         self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
@@ -4002,13 +3999,10 @@ def test_serialization(self):
         buf = io.BytesIO(serialized)
         utf8_bytes = b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc'
         utf8_str = utf8_bytes.decode('utf-8')
-        if PY3:
-            loaded_utf8 = torch.load(buf, encoding='utf-8')
-            self.assertEqual(loaded_utf8, [utf8_str, torch.zeros(1, dtype=torch.float), 2])
-            buf.seek(0)
-            loaded_bytes = torch.load(buf, encoding='bytes')
-        else:
-            loaded_bytes = torch.load(buf)
+        loaded_utf8 = torch.load(buf, encoding='utf-8')
+        self.assertEqual(loaded_utf8, [utf8_str, torch.zeros(1, dtype=torch.float), 2])
+        buf.seek(0)
+        loaded_bytes = torch.load(buf, encoding='bytes')
         self.assertEqual(loaded_bytes, [utf8_bytes, torch.zeros(1, dtype=torch.float), 2])
 
     def test_serialization_filelike(self):
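
Note on the simplified branch above: torch.load documents its encoding argument as controlling how byte strings pickled by Python 2 are decoded, so keeping only the encoding='utf-8' and encoding='bytes' paths is sufficient once Python 2 support is dropped. A minimal sketch of the same distinction using plain pickle (the hand-written legacy payload below is an illustrative assumption, not taken from this test):

import pickle

# Protocol-0 pickle of the Python 2 byte string 'abc' (hand-crafted for illustration).
legacy = b"S'abc'\np0\n."

assert pickle.loads(legacy, encoding='utf-8') == 'abc'   # decoded into a Python 3 str
assert pickle.loads(legacy, encoding='bytes') == b'abc'  # kept as raw bytes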
@@ -4292,7 +4286,6 @@ def check_map_locations(map_locations, tensor_class, intended_device):
         )
 
     @unittest.skipIf(torch.cuda.is_available(), "Testing torch.load on CPU-only machine")
-    @unittest.skipIf(not PY3, "Test tensors were serialized using python 3")
     def test_load_nonexistent_device(self):
         # Setup: create a serialized file object with a 'cuda:0' restore location
         # The following was generated by saving a torch.randn(2, device='cuda') tensor.
@@ -11025,27 +11018,6 @@ def test_nonzero_non_diff(self, device):
         nz = x.nonzero()
         self.assertFalse(nz.requires_grad)
 
-    def test_pdist_norm(self, device):
-        def test_pdist_single(shape, device, p, dtype, trans):
-            x = torch.randn(shape, dtype=dtype, device=device)
-            if trans:
-                x.transpose_(-2, -1)
-            actual = torch.pdist(x, p=p)
-            expected = brute_pdist(x, p=p)
-            self.assertEqual(expected.shape, actual.shape)
-            self.assertTrue(torch.allclose(expected, actual))
-
-        for shape in [(4, 5), (3, 2), (2, 1)]:
-            for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
-                for trans in [False, True]:
-                    for dtype in [torch.float32, torch.float64]:
-                        test_pdist_single(shape, device, p, dtype, trans)
-
-        # do a simplified comparison with big inputs, see:
-        # https://github.com/pytorch/pytorch/issues/15511
-        for dtype in [torch.float32, torch.float64]:
-            test_pdist_single((1000, 2), device, 2, dtype, False)
-
     def test_atan2(self, device):
         def _test_atan2_with_size(size, device):
             a = torch.rand(size=size, device=device, dtype=torch.double)
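
The deleted test_pdist_norm checked torch.pdist against the brute_pdist reference helper from common_utils, which is also dropped from the imports in the first hunk. brute_pdist's implementation is not part of this diff; the sketch below (brute_pdist_sketch is a hypothetical stand-in) shows the kind of brute-force reference such a test compares against: build the full pairwise p-norm distance matrix and keep the strict upper triangle, which is the layout torch.pdist returns.

import torch

def brute_pdist_sketch(x, p=2):
    # Full n x n matrix of p-norm distances between every pair of rows of x.
    dists = torch.norm(x.unsqueeze(-2) - x.unsqueeze(-3), p=p, dim=-1)
    # torch.pdist returns only the strict upper triangle, row by row.
    n = x.shape[-2]
    i, j = torch.triu_indices(n, n, offset=1)
    return dists[..., i, j]

# Spot check against the built-in, mirroring what the removed test did with allclose.
x = torch.randn(4, 5)
assert torch.allclose(torch.pdist(x, p=2), brute_pdist_sketch(x, p=2))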
