1616from torch .testing ._internal .common_device_type import (
1717 dtypes ,
1818 instantiate_device_type_tests ,
19+ largeTensorTest ,
1920 onlyCUDA ,
2021 OpDTypes ,
2122 ops ,
@@ -1358,22 +1359,13 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
13581359 # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
13591360 foreach_copy_ = ForeachFuncWrapper (op .inplace_variant )
13601361
1361- tested_large_input = False
1362-
13631362 for sample in op .sample_inputs (
13641363 device , dtype , noncontiguous = False , allow_higher_dtype_scalars = True
13651364 ):
13661365 for src_dtype in floating_types_and (torch .half , torch .bfloat16 ):
13671366 if src_dtype == dtype :
13681367 continue
13691368 self_tensors = [t .clone () for t in sample .input ]
1370- if not tested_large_input :
1371- # see https://github.com/pytorch/pytorch/issues/156261
1372- self_tensors .append (
1373- torch .empty (2 ** 31 + 1 , device = device , dtype = dtype )
1374- )
1375- tested_large_input = True
1376-
13771369 src_tensors = [t .to (src_dtype ) for t in self_tensors ]
13781370 out = foreach_copy_ (
13791371 (self_tensors , src_tensors ), is_cuda = True , expect_fastpath = True
@@ -1385,6 +1377,17 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
13851377 for t , ref_t in zip (out , ref_out ):
13861378 self .assertTrue (torch .equal (t , ref_t ))
13871379
1380+ @onlyCUDA
1381+ @largeTensorTest ("40GB" , device = "cuda" )
1382+ def test_foreach_copy_with_multi_dtypes_large_input (self ):
1383+ # see https://github.com/pytorch/pytorch/issues/156261
1384+ self_tensor = torch .empty (2 ** 31 + 1 , device = "cuda" , dtype = torch .float32 )
1385+ src_tensor = torch .ones (2 ** 31 + 1 , device = "cuda" , dtype = torch .bfloat16 )
1386+
1387+ torch ._foreach_copy_ ([self_tensor ], [src_tensor ])
1388+ ref_out = torch .empty_like (self_tensor ).copy_ (src_tensor )
1389+ self .assertEqual (self_tensor , ref_out )
1390+
13881391 @requires_cuda
13891392 @ops (filter (lambda op : op .name == "_foreach_copy" , foreach_binary_op_db ))
13901393 def test_foreach_copy_with_different_device_inputs (self , device , dtype , op ):
0 commit comments