|
| 1 | +#include <ATen/native/BinaryOps.h> |
1 | 2 | #include <ATen/native/ForeachUtils.h> |
2 | 3 | #include <ATen/ops/_foreach_add_native.h> |
3 | 4 | #include <ATen/ops/_foreach_addcdiv_native.h> |
|
8 | 9 | #include <ATen/ops/_foreach_lerp_native.h> |
9 | 10 | #include <ATen/ops/_foreach_mul_native.h> |
10 | 11 | #include <ATen/ops/_foreach_pow_native.h> |
| 12 | +#include <ATen/ops/_foreach_sub_native.h> |
11 | 13 |
|
12 | 14 | #include <ATen/native/xpu/sycl/ForeachBinaryOpScalarKernels.h> |
13 | 15 | #include <ATen/native/xpu/sycl/ForeachPointwiseOpScalarKernels.h> |
@@ -37,7 +39,33 @@ namespace native { |
37 | 39 | return xpu::FOREACH_BINARY_SCALAR_KERNEL_NAME(NAME)(tensors, scalar); \ |
38 | 40 | } |
39 | 41 |
|
| 42 | +// Scalar foreach dispatch for subtraction. Generates an in-place (`_` suffix, returns void) and an out-of-place (returns std::vector<Tensor>) function; each runs check_foreach_api_restrictions, then sub_check(tensors[0], scalar), then takes the XPU kernel if can_use_fast_route allows it and the slow fallback otherwise. |
| 43 | +// Following the torch.sub logic, a boolean scalar is not allowed for subtraction; sub_check is presumably what enforces this (it is defined outside this view - confirm). Note that only the first tensor of the list is passed to it. |
| 44 | +#define FOREACH_BINARY_OP_SCALAR_NO_BOOLEAN(NAME, DIV_OP) \ |
| 45 | + void foreach_tensor_##NAME##_scalar_kernel_xpu_( \ |
| 46 | + TensorList tensors, const Scalar& scalar) { \ |
| 47 | + check_foreach_api_restrictions(tensors); \ |
| 48 | + sub_check(tensors[0], scalar); \ |
| 49 | + if (!can_use_fast_route(tensors, scalar, DIV_OP)) { \ |
| 50 | + return foreach_tensor_##NAME##_scalar_kernel_slow_(tensors, scalar); \ |
| 51 | + } \ |
| 52 | + \ |
| 53 | + xpu::FOREACH_BINARY_SCALAR_INPLACE_KERNEL_NAME(NAME)(tensors, scalar); \ |
| 54 | + } \ |
| 55 | + \ |
| 56 | + std::vector<Tensor> foreach_tensor_##NAME##_scalar_kernel_xpu( \ |
| 57 | + TensorList tensors, const Scalar& scalar) { \ |
| 58 | + check_foreach_api_restrictions(tensors); \ |
| 59 | + sub_check(tensors[0], scalar); \ |
| 60 | + if (!can_use_fast_route(tensors, scalar, DIV_OP)) { \ |
| 61 | + return foreach_tensor_##NAME##_scalar_kernel_slow(tensors, scalar); \ |
| 62 | + } \ |
| 63 | + \ |
| 64 | + return xpu::FOREACH_BINARY_SCALAR_KERNEL_NAME(NAME)(tensors, scalar); \ |
| 65 | + } |
| 66 | + |
40 | 67 | FOREACH_BINARY_OP_SCALAR(add, /*div_op*/ false); |
| 68 | +FOREACH_BINARY_OP_SCALAR_NO_BOOLEAN(sub, /*div_op*/ false); // sub uses the NO_BOOLEAN variant so that sub_check runs before the fast/slow dispatch |
41 | 69 | FOREACH_BINARY_OP_SCALAR(mul, /*div_op*/ false); |
42 | 70 | FOREACH_BINARY_OP_SCALAR(div, /*div_op*/ true); |
43 | 71 | FOREACH_BINARY_OP_SCALAR(clamp_max, /*div_op*/ true); |
|
0 commit comments