
Commit 1a3161a

[Cherry-pick] Fix copysign + scalar correctness issue (pytorch#153098)
* [Testing] Add copysign from scalar regression test (pytorch#152997)
  Instead of adding it just for the MPS backend, add it to OpInfo.
  Fixes pytorch#152582
  Pull Request resolved: pytorch#152997
  Approved by: https://github.com/wdvr
  (cherry picked from commit 9919d6b)
* Spiritual cherry-pick of 52cbcac
* [CI] Skip test_copy_large_tensor on M2-15 runners (pytorch#150377)
  They have more than 12Gb of memory, but running this test may cause OOM in CI.
  Pull Request resolved: pytorch#150377
  Approved by: https://github.com/atalman
1 parent 27e9ca5 commit 1a3161a
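
For context, here is a minimal Python sketch (not part of the commit) of the regression class this cherry-pick addresses: copysign with a plain Python scalar on the MPS backend should match the CPU result. The exact failing values are in the linked issue; this just illustrates the shape of the check.

    import torch

    # Hypothetical repro sketch for the copysign-from-scalar regression
    # (see pytorch#152582): compare the MPS result against a CPU reference.
    if torch.backends.mps.is_available():
        x = torch.tensor([1.0, -2.0, 3.0], device="mps")
        cpu_ref = torch.copysign(x.cpu(), -3.14)
        mps_out = torch.copysign(x, -3.14)
        torch.testing.assert_close(mps_out.cpu(), cpu_ref)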

File tree

3 files changed: +21 -0 lines


aten/src/ATen/native/mps/operations/BinaryKernel.mm

Lines changed: 14 additions & 0 deletions
@@ -34,10 +34,24 @@
 static void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name, bool supports_dense = true) {
   TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");

+  auto convert_double_scalar = [](Tensor& t) {
+    if (t.dim() != 0) {
+      return;
+    }
+    if (t.scalar_type() == kDouble) {
+      t = t.to(kFloat);
+    } else if (t.scalar_type() == kComplexDouble) {
+      t = t.to(kComplexFloat);
+    }
+  };
+
   Tensor input = iter.input(0);
   Tensor other = iter.input(1);
   Tensor out = iter.output();

+  convert_double_scalar(input);
+  convert_double_scalar(other);
+
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* mpsStream = getCurrentMPSStream();
   const uint32_t nDim = iter.ndim();
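
Why the new convert_double_scalar lambda targets only 0-dim tensors: when a Python scalar participates in a binary op, ATen wraps it as a 0-dim float64 tensor, and Metal has no float64 type, so the wrapped operand must be downcast before the kernel runs. A hedged sketch of that behavior from the Python side (the wrapping itself happens inside TensorIterator):

    import torch

    # A 0-dim double tensor stands in for ATen's wrapped-scalar operand;
    # the lambda above converts exactly this dim/dtype combination.
    wrapped = torch.tensor(-3.14, dtype=torch.float64)
    assert wrapped.dim() == 0
    converted = wrapped.to(torch.float32)   # what convert_double_scalar produces
    print(converted.dtype)                  # torch.float32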

test/test_mps.py

Lines changed: 1 addition & 0 deletions
@@ -7480,6 +7480,7 @@ def compare_mm(m, n, k, dtype=torch.float):

     @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test")
     @unittest.skipIf(MACOS_VERSION < 14.0, "Can't allocate 4Gb tensor on MacOS 13")
+    @unittest.skipIf(IS_CI, "May be fixes https://github.com/pytorch/pytorch/issues/149999")
     def test_copy_large(self):
         """ Test that copy of 4Gb+ tensors works """
         x = torch.ones((2**30 + 11,), dtype=torch.float32)
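
A quick back-of-the-envelope check on the 12Gb skip threshold above: with float32 elements, the source tensor alone is just over 4 GiB, and the copy needs room for a destination as well.

    n = 2**30 + 11          # element count from the test above
    print(n * 4 / 2**30)    # ~4.0 GiB per float32 tensor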

torch/testing/_internal/common_methods_invocations.py

Lines changed: 6 additions & 0 deletions
@@ -6165,6 +6165,11 @@ def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs):
 def sample_inputs_corrcoef(op_info, device, dtype, requires_grad, **kwargs):
     return (SampleInput(t) for t in _generate_correlation_inputs(device, dtype, requires_grad))

+def sample_inputs_copysign(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_elementwise_binary(op_info, device, dtype, requires_grad, **kwargs)
+    if dtype.is_floating_point:
+        yield SampleInput(make_tensor(5, dtype=dtype, device=device, requires_grad=requires_grad), -3.14)
+

 def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
     for t in _generate_correlation_inputs(device, dtype, requires_grad):
@@ -12882,6 +12887,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
         DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps'),
         DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),)),
     BinaryUfuncInfo('copysign',
+                    sample_inputs_func=sample_inputs_copysign,
                     dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                     dtypesIfHpu=custom_types(torch.float32, torch.bfloat16),
                     promotes_int_to_float=True,
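
A hedged sketch of what the new sample_inputs_copysign sample exercises when the OpInfo tests run: copysign with a tensor input and a plain Python scalar other, the exact pairing that regressed on MPS. make_tensor here is torch.testing.make_tensor, the same helper the OpInfo code uses.

    import torch
    from torch.testing import make_tensor

    t = make_tensor(5, dtype=torch.float32, device="cpu", requires_grad=True)
    out = torch.copysign(t, -3.14)   # scalar `other`, as in the new SampleInput
    print(out)                       # every element now carries a negative sign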
