Skip to content

Commit c76b235

Browse files
pytorchbot and janeyx99 authored
Move out super large one off foreach_copy test (pytorch#158880)
Move out super large one off foreach_copy test (pytorch#156876) Pull Request resolved: pytorch#156876 Approved by: https://github.com/albanD, https://github.com/jeffdaily (cherry picked from commit 50b2069) Co-authored-by: Jane Xu <[email protected]>
1 parent 20a0e22 commit c76b235

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

test/test_foreach.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch.testing._internal.common_device_type import (
1717
dtypes,
1818
instantiate_device_type_tests,
19+
largeTensorTest,
1920
onlyCUDA,
2021
OpDTypes,
2122
ops,
@@ -1358,22 +1359,13 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
13581359
# check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
13591360
foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
13601361

1361-
tested_large_input = False
1362-
13631362
for sample in op.sample_inputs(
13641363
device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
13651364
):
13661365
for src_dtype in floating_types_and(torch.half, torch.bfloat16):
13671366
if src_dtype == dtype:
13681367
continue
13691368
self_tensors = [t.clone() for t in sample.input]
1370-
if not tested_large_input:
1371-
# see https://github.com/pytorch/pytorch/issues/156261
1372-
self_tensors.append(
1373-
torch.empty(2**31 + 1, device=device, dtype=dtype)
1374-
)
1375-
tested_large_input = True
1376-
13771369
src_tensors = [t.to(src_dtype) for t in self_tensors]
13781370
out = foreach_copy_(
13791371
(self_tensors, src_tensors), is_cuda=True, expect_fastpath=True
@@ -1385,6 +1377,17 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
13851377
for t, ref_t in zip(out, ref_out):
13861378
self.assertTrue(torch.equal(t, ref_t))
13871379

@onlyCUDA
@largeTensorTest("40GB", device="cuda")
def test_foreach_copy_with_multi_dtypes_large_input(self):
    """Exercise _foreach_copy_ on a single tensor whose numel exceeds 2**31.

    Split out as its own test because of the large memory footprint;
    see https://github.com/pytorch/pytorch/issues/156261.
    """
    # numel deliberately larger than 2**31 (the regression in the linked issue)
    numel = 2**31 + 1
    dst = torch.empty(numel, device="cuda", dtype=torch.float32)
    src = torch.ones(numel, device="cuda", dtype=torch.bfloat16)

    torch._foreach_copy_([dst], [src])

    # Reference result: plain Tensor.copy_ into a like-shaped buffer.
    expected = torch.empty_like(dst).copy_(src)
    self.assertEqual(dst, expected)
13881391
@requires_cuda
13891392
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
13901393
def test_foreach_copy_with_different_device_inputs(self, device, dtype, op):

0 commit comments

Comments
 (0)