16
16
from torch .testing ._internal .common_device_type import (
17
17
dtypes ,
18
18
instantiate_device_type_tests ,
19
+ largeTensorTest ,
19
20
onlyCUDA ,
20
21
OpDTypes ,
21
22
ops ,
@@ -1358,22 +1359,13 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
1358
1359
# check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
1359
1360
foreach_copy_ = ForeachFuncWrapper (op .inplace_variant )
1360
1361
1361
- tested_large_input = False
1362
-
1363
1362
for sample in op .sample_inputs (
1364
1363
device , dtype , noncontiguous = False , allow_higher_dtype_scalars = True
1365
1364
):
1366
1365
for src_dtype in floating_types_and (torch .half , torch .bfloat16 ):
1367
1366
if src_dtype == dtype :
1368
1367
continue
1369
1368
self_tensors = [t .clone () for t in sample .input ]
1370
- if not tested_large_input :
1371
- # see https://github.com/pytorch/pytorch/issues/156261
1372
- self_tensors .append (
1373
- torch .empty (2 ** 31 + 1 , device = device , dtype = dtype )
1374
- )
1375
- tested_large_input = True
1376
-
1377
1369
src_tensors = [t .to (src_dtype ) for t in self_tensors ]
1378
1370
out = foreach_copy_ (
1379
1371
(self_tensors , src_tensors ), is_cuda = True , expect_fastpath = True
@@ -1385,6 +1377,17 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
1385
1377
for t , ref_t in zip (out , ref_out ):
1386
1378
self .assertTrue (torch .equal (t , ref_t ))
1387
1379
1380
@onlyCUDA
@largeTensorTest("40GB", device="cuda")
def test_foreach_copy_with_multi_dtypes_large_input(self):
    # see https://github.com/pytorch/pytorch/issues/156261
    # Copies bfloat16 -> float32 through _foreach_copy_ on a tensor with
    # 2**31 + 1 elements, i.e. one more element than 2**31 (a count that no
    # longer fits in a signed 32-bit integer).
    numel = 2**31 + 1
    dst = torch.empty(numel, device="cuda", dtype=torch.float32)
    src = torch.ones(numel, device="cuda", dtype=torch.bfloat16)

    # Run the foreach variant on single-tensor lists, then build the
    # reference result with a plain Tensor.copy_ into a fresh buffer.
    torch._foreach_copy_([dst], [src])
    expected = torch.empty_like(dst).copy_(src)
    self.assertEqual(dst, expected)
1390
+
1388
1391
@requires_cuda
1389
1392
@ops (filter (lambda op : op .name == "_foreach_copy" , foreach_binary_op_db ))
1390
1393
def test_foreach_copy_with_different_device_inputs (self , device , dtype , op ):
0 commit comments