Skip to content

Commit 2042d21

Browse files
malfet and pytorchmergebot
authored and committed
[MPS] Migrate round unary op to Metal (pytorch#161712)
And actually use the right function, as [`torch.round`](https://docs.pytorch.org/docs/stable/generated/torch.round.html) doesn't use `std::round`, but rather `std::rint`, which can easily be seen by running something like:

```python
import torch
print(torch.arange(-3., 3., step=.5, device='mps').round())
print(torch.arange(-3., 3., step=.5, device='mps').cpu().round())
```

Before this change it printed:

```
tensor([-3., -3., -2., -2., -1., -1., 0., 1., 1., 2., 2., 3.], device='mps:0')
tensor([-3., -2., -2., -2., -1., -0., 0., 0., 1., 2., 2., 2.])
```

But after this change the results match.

Pull Request resolved: pytorch#161712
Approved by: https://github.com/dcci
1 parent 4fd761f commit 2042d21

File tree

5 files changed

+21
-7
lines changed

5 files changed

+21
-7
lines changed

aten/src/ATen/native/mps/kernels/UnaryKernel.metal

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,17 @@ struct round_decimals_functor {
503503
}
504504
};
505505

506+
// Unary functor backing the `round` op registered further down via
// REGISTER_UNARY_OP(round, ...). It uses rint() rather than round(); per the
// commit description this matches torch.round, whose CPU output in the example
// above maps 0.5 -> 0 and 2.5 -> 2 instead of rounding halfway cases away from
// zero.
struct round_functor {
507+
// Floating-point overload: widen x to float, apply rint(), cast back to T.
template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
508+
inline T operator()(const T x) {
509+
return static_cast<T>(rint(float(x)));
510+
}
511+
// Integral overload: the input is returned unchanged.
template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
512+
inline T operator()(const T x) {
513+
return x;
514+
}
515+
};
516+
506517
DEFINE_UNARY_FLOATING_FUNCTOR(erf);
507518
DEFINE_UNARY_FLOATING_FUNCTOR(erfc);
508519
DEFINE_UNARY_FLOATING_FUNCTOR(erfinv);
@@ -515,6 +526,13 @@ REGISTER_UNARY_OP(neg, char, char);
515526
REGISTER_UNARY_OP(neg, uchar, uchar);
516527
REGISTER_UNARY_OP(neg, float, float);
517528
REGISTER_UNARY_OP(neg, half, half);
529+
REGISTER_UNARY_OP(round, int, int);
530+
REGISTER_UNARY_OP(round, long, long);
531+
REGISTER_UNARY_OP(round, short, short);
532+
REGISTER_UNARY_OP(round, char, char);
533+
REGISTER_UNARY_OP(round, uchar, uchar);
534+
REGISTER_UNARY_OP(round, float, float);
535+
REGISTER_UNARY_OP(round, half, half);
518536

519537
REGISTER_UNARY_OP(bitwise_not, int, int);
520538
REGISTER_UNARY_OP(bitwise_not, long, long);
@@ -558,6 +576,7 @@ REGISTER_UNARY_OP(abs, half, half);
558576

559577
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
560578
REGISTER_UNARY_OP(neg, bfloat, bfloat);
579+
REGISTER_UNARY_OP(round, bfloat, bfloat);
561580
REGISTER_UNARY_OP(abs, bfloat, bfloat);
562581
INSTANTIATE_UNARY_KERNELS2(half, half);
563582
INSTANTIATE_UNARY_KERNELS2(float, float);

aten/src/ATen/native/mps/operations/UnaryKernel.mm

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
5050
REGISTER_UNARY_TI_DISPATCH(log);
5151
REGISTER_UNARY_TI_DISPATCH(log1p);
5252
REGISTER_UNARY_TI_DISPATCH(bitwise_not);
53+
REGISTER_UNARY_TI_DISPATCH(round);
5354
REGISTER_UNARY_TI_DISPATCH(sigmoid);
5455
REGISTER_DISPATCH(round_decimals_stub, round_decimals_kernel);
5556
} // namespace at::native

aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ static void unary_op(const Tensor& self,
184184

185185
REGISTER_MPS_UNARY_STUB(ceil, ceil);
186186
REGISTER_MPS_UNARY_STUB(floor, floor);
187-
REGISTER_MPS_UNARY_STUB(round, round);
188187
REGISTER_MPS_UNARY_STUB(trunc, truncate);
189188

190189
#define CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \

torch/_inductor/codegen/mps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def randint64(
366366

367367
@staticmethod
368368
def round(x: CSEVariable) -> str:
369-
return f"metal::round({x})"
369+
return f"metal::rint({x})"
370370

371371
@staticmethod
372372
def pow(a: CSEVariable, b: CSEVariable) -> str:

torch/testing/_internal/common_mps.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,6 @@ def mps_ops_modifier(
439439
torch.uint8,
440440
torch.int8,
441441
],
442-
# round not working properly for float16 and bfloat16
443-
"round": [torch.float16, torch.bfloat16],
444-
"rounddecimals_0": [torch.bfloat16],
445442
}
446443

447444
if MACOS_VERSION < 15.0:
@@ -725,8 +722,6 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
725722
"signal.windows.kaiser": [torch.float32],
726723
"signal.windows.nuttall": [torch.float32],
727724
"eye": [torch.float16, torch.float32],
728-
# round not working properly for float16
729-
"round": [torch.float16],
730725
# topk fails with duplicate indices
731726
"topk": [torch.float16],
732727
}

0 commit comments

Comments
 (0)