Skip to content

Commit 08af311

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS] Fix type promotion for torch.floor_divide (pytorch#149233)
And delete some duplicating glue code by relying on the stub After this change `torch.arange(10, device = 'mps') // torch.arange(10., device='mps')` will return tensor of floats, which is a common dtype for float + integral operation, rather than tensor of ints Checked by `test_div2` inductor testing Pull Request resolved: pytorch#149233 Approved by: https://github.com/atalman ghstack dependencies: pytorch#149216
1 parent eb7bf42 commit 08af311

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include <ATen/ops/atan2_native.h>
1515
#include <ATen/ops/div_native.h>
1616
#include <ATen/ops/eq_native.h>
17-
#include <ATen/ops/floor_divide_native.h>
1817
#include <ATen/ops/fmod_native.h>
1918
#include <ATen/ops/ge_native.h>
2019
#include <ATen/ops/gt_native.h>
@@ -447,19 +446,8 @@ static void add_sub_lerp_template(const Tensor& self,
447446
}
448447
}
449448

450-
Tensor& floor_divide_out_mps(const Tensor& self, const Tensor& other, Tensor& result) {
451-
mps::div_mode_template(self, other, "floor", result, "floor_divide_out");
452-
return result;
453-
}
454-
455-
Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {
456-
Tensor output = at::empty_like(self);
457-
mps::div_mode_template(self, other, "floor", output, "floor_divide");
458-
return output;
459-
}
460-
461-
Tensor& floor_divide_mps_(Tensor& self, const Tensor& other) {
462-
return floor_divide_out_mps(self, other, self);
449+
static void div_floor_kernel_mps(TensorIteratorBase& iter) {
450+
mps::div_mode_template(iter.input(0), iter.input(1), "floor", iter.output(0), "floor_divide_out");
463451
}
464452

465453
TORCH_IMPL_FUNC(remainder_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
@@ -538,4 +526,6 @@ Tensor floor_divide_mps(const Tensor& self, const Tensor& other) {
538526
TORCH_IMPL_FUNC(lerp_Scalar_mps)(const Tensor& self, const Tensor& end, const Scalar& weight, const Tensor& out) {
539527
mps::add_sub_lerp_template(self, end, weight, out, "lerp");
540528
}
529+
530+
REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel_mps);
541531
} // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2749,23 +2749,20 @@
27492749
device_check: NoCheck # TensorIterator
27502750
variants: function, method
27512751
dispatch:
2752-
CPU, CUDA: floor_divide
2753-
MPS: floor_divide_mps
2752+
CPU, CUDA, MPS: floor_divide
27542753
SparseCPU, SparseCUDA: floor_divide_sparse
27552754

27562755
- func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
27572756
device_check: NoCheck # TensorIterator
27582757
variants: method
27592758
dispatch:
2760-
CPU, CUDA: floor_divide_
2761-
MPS: floor_divide_mps_
2759+
CPU, CUDA, MPS: floor_divide_
27622760
SparseCPU, SparseCUDA: floor_divide_sparse_
27632761

27642762
- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
27652763
device_check: NoCheck # TensorIterator
27662764
dispatch:
2767-
CPU, CUDA: floor_divide_out
2768-
MPS: floor_divide_out_mps
2765+
CPU, CUDA, MPS: floor_divide_out
27692766
SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim
27702767

27712768
- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor

test/inductor/test_mps_basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def fn(a):
190190
"test_cumsum_inf",
191191
"test_custom_op_2",
192192
"test_div1",
193+
"test_div2",
193194
"test_div3",
194195
"test_erfinv",
195196
"test_floordiv",

0 commit comments

Comments
 (0)