
Commit 5e29831

cfgfungtoyxu and Yutao Xu authored
Add aten::foreach_sub and its variants (#1034)
Add the following operators:

- foreach_sub.List
- foreach_sub.Scalar
- foreach_sub.ScalarList
- foreach_sub_.List
- foreach_sub_.Scalar
- foreach_sub_.ScalarList

Co-authored-by: Yutao Xu <[email protected]>
1 parent dfc3461 commit 5e29831

29 files changed, +200 -62 lines changed
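For orientation (an added illustration, not part of this commit): the six variants listed in the commit message correspond to the `at::_foreach_sub` / `at::_foreach_sub_` overloads that ATen generates from these schemas. A minimal sketch, assuming a libtorch build; the XPU kernels registered here are only reached when the tensor lists live on an XPU device (which requires an XPU-enabled PyTorch build), otherwise dispatch goes to another backend:

```cpp
#include <ATen/ATen.h>

#include <vector>

int main() {
  // Two small tensor lists on the default CPU device. Move them to an XPU
  // device (XPU-enabled build required) to reach the kernels added here.
  std::vector<at::Tensor> a{at::full({2}, 5.0), at::full({2}, 3.0)};
  std::vector<at::Tensor> b{at::ones({2}), at::ones({2})};
  std::vector<at::Scalar> s{1.0, 2.0};

  auto r_list = at::_foreach_sub(a, b, /*alpha=*/2);    // _foreach_sub.List
  auto r_scalar = at::_foreach_sub(a, at::Scalar(1.5)); // _foreach_sub.Scalar
  auto r_slist = at::_foreach_sub(a, s);                // _foreach_sub.ScalarList

  at::_foreach_sub_(a, b, /*alpha=*/2);                 // _foreach_sub_.List (in-place)
  at::_foreach_sub_(a, at::Scalar(1.5));                // _foreach_sub_.Scalar (in-place)
  at::_foreach_sub_(a, s);                              // _foreach_sub_.ScalarList (in-place)
  return 0;
}
```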

src/ATen/native/xpu/AiryAi.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
-#include <ATen/native/UnaryOps.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
+#include <ATen/native/UnaryOps.h>
 #include <ATen/native/xpu/sycl/AiryAiKernel.h>
 
 namespace at {

src/ATen/native/xpu/BinaryOps.cpp

Lines changed: 1 addition & 1 deletion
@@ -17,9 +17,9 @@
 #include <ATen/native/xpu/sycl/ChebyshevPolynomialKernels.h>
 #include <ATen/native/xpu/sycl/CopysignKernel.h>
 #include <ATen/native/xpu/sycl/GcdLcmKernels.h>
-#include <ATen/native/xpu/sycl/IGammaKernel.h>
 #include <ATen/native/xpu/sycl/HermitePolynomialHKernel.h>
 #include <ATen/native/xpu/sycl/HermitePolynomialHeKernel.h>
+#include <ATen/native/xpu/sycl/IGammaKernel.h>
 #include <ATen/native/xpu/sycl/LaguerrePolynomialLKernel.h>
 #include <ATen/native/xpu/sycl/LegendrePolynomialPKernel.h>
 #include <ATen/native/xpu/sycl/LogAddExpKernels.h>

src/ATen/native/xpu/Embedding.cpp

Lines changed: 0 additions & 1 deletion
@@ -35,6 +35,5 @@ Tensor& embedding_renorm_xpu_(
       self, indices, max_norm, norm_type);
 }
 
-
 } // namespace native
 } // namespace at

src/ATen/native/xpu/ForeachOpList.cpp

Lines changed: 9 additions & 9 deletions
@@ -4,17 +4,17 @@
 #include <ATen/ops/_foreach_addcmul_native.h>
 #include <ATen/ops/_foreach_clamp_max_native.h>
 #include <ATen/ops/_foreach_clamp_min_native.h>
+#include <ATen/ops/_foreach_copy_native.h>
 #include <ATen/ops/_foreach_div_native.h>
 #include <ATen/ops/_foreach_lerp_native.h>
 #include <ATen/ops/_foreach_mul_native.h>
-#include <ATen/ops/_foreach_clamp_min_native.h>
-#include <ATen/ops/_foreach_copy_native.h>
 #include <ATen/ops/_foreach_pow_native.h>
+#include <ATen/ops/_foreach_sub_native.h>
 
 #include <ATen/native/xpu/sycl/ForeachBinaryOpListKernels.h>
+#include <ATen/native/xpu/sycl/ForeachCopyKernels.h>
 #include <ATen/native/xpu/sycl/ForeachPointwiseOpListKernels.h>
 #include <ATen/native/xpu/sycl/ForeachTernaryOpListKernels.h>
-#include <ATen/native/xpu/sycl/ForeachCopyKernels.h>
 
 #include <ATen/ops/empty_like.h>
 
@@ -68,6 +68,7 @@ namespace native {
 }
 
 FOREACH_BINARY_OP_LIST_ALPHA(add);
+FOREACH_BINARY_OP_LIST_ALPHA(sub);
 FOREACH_BINARY_OP_LIST(mul, false);
 FOREACH_BINARY_OP_LIST(div, true);
 FOREACH_BINARY_OP_LIST(clamp_max, true);
@@ -154,12 +155,11 @@ void foreach_tensor_copy_list_kernel_xpu_(
     TensorList self,
     TensorList src,
     bool non_blocking) {
-  check_foreach_api_restrictions(self, src);
-  if (!can_use_fast_route(
-          self, src, /* does_op_promote_integer_inputs_to_float */ false)) {
-    return foreach_tensor_copy_list_kernel_slow_(
-        self, src, non_blocking);
-  }
+  check_foreach_api_restrictions(self, src);
+  if (!can_use_fast_route(
+          self, src, /* does_op_promote_integer_inputs_to_float */ false)) {
+    return foreach_tensor_copy_list_kernel_slow_(self, src, non_blocking);
+  }
 
   xpu::foreach_copy_list_kernel_(self, src);
 

src/ATen/native/xpu/ForeachOpScalar.cpp

Lines changed: 28 additions & 0 deletions
@@ -1,3 +1,4 @@
+#include <ATen/native/BinaryOps.h>
 #include <ATen/native/ForeachUtils.h>
 #include <ATen/ops/_foreach_add_native.h>
 #include <ATen/ops/_foreach_addcdiv_native.h>
@@ -8,6 +9,7 @@
 #include <ATen/ops/_foreach_lerp_native.h>
 #include <ATen/ops/_foreach_mul_native.h>
 #include <ATen/ops/_foreach_pow_native.h>
+#include <ATen/ops/_foreach_sub_native.h>
 
 #include <ATen/native/xpu/sycl/ForeachBinaryOpScalarKernels.h>
 #include <ATen/native/xpu/sycl/ForeachPointwiseOpScalarKernels.h>
@@ -37,7 +39,33 @@ namespace native {
     return xpu::FOREACH_BINARY_SCALAR_KERNEL_NAME(NAME)(tensors, scalar); \
   }
 
+// In the case of subtraction, we dont allow scalar to be boolean following the
+// torch.sub logic
+#define FOREACH_BINARY_OP_SCALAR_NO_BOOLEAN(NAME, DIV_OP) \
+  void foreach_tensor_##NAME##_scalar_kernel_xpu_( \
+      TensorList tensors, const Scalar& scalar) { \
+    check_foreach_api_restrictions(tensors); \
+    sub_check(tensors[0], scalar); \
+    if (!can_use_fast_route(tensors, scalar, DIV_OP)) { \
+      return foreach_tensor_##NAME##_scalar_kernel_slow_(tensors, scalar); \
+    } \
+ \
+    xpu::FOREACH_BINARY_SCALAR_INPLACE_KERNEL_NAME(NAME)(tensors, scalar); \
+  } \
+ \
+  std::vector<Tensor> foreach_tensor_##NAME##_scalar_kernel_xpu( \
+      TensorList tensors, const Scalar& scalar) { \
+    check_foreach_api_restrictions(tensors); \
+    sub_check(tensors[0], scalar); \
+    if (!can_use_fast_route(tensors, scalar, DIV_OP)) { \
+      return foreach_tensor_##NAME##_scalar_kernel_slow(tensors, scalar); \
+    } \
+ \
+    return xpu::FOREACH_BINARY_SCALAR_KERNEL_NAME(NAME)(tensors, scalar); \
+  }
+
 FOREACH_BINARY_OP_SCALAR(add, /*div_op*/ false);
+FOREACH_BINARY_OP_SCALAR_NO_BOOLEAN(sub, /*div_op*/ false);
 FOREACH_BINARY_OP_SCALAR(mul, /*div_op*/ false);
 FOREACH_BINARY_OP_SCALAR(div, /*div_op*/ true);
 FOREACH_BINARY_OP_SCALAR(clamp_max, /*div_op*/ true);
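The `sub_check` calls in the new `FOREACH_BINARY_OP_SCALAR_NO_BOOLEAN` macro mirror `torch.sub`, which refuses boolean operands. A small sketch of the expected behavior (my illustration, not code from this commit; the check applies on any backend, so it runs wherever the tensors live):

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

#include <iostream>
#include <vector>

int main() {
  // Numeric inputs go through as usual.
  std::vector<at::Tensor> ok{at::ones({2})};
  at::_foreach_sub_(ok, at::Scalar(1));

  // A boolean tensor with a boolean scalar is rejected, matching torch.sub,
  // which points users to ^ / logical_xor() for bool masks instead.
  std::vector<at::Tensor> masks{at::ones({2}, at::kBool)};
  try {
    at::_foreach_sub_(masks, at::Scalar(true));
  } catch (const c10::Error& e) {
    std::cout << "rejected as expected: " << e.what() << "\n";
  }
  return 0;
}
```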

src/ATen/native/xpu/ForeachOpScalarList.cpp

Lines changed: 35 additions & 0 deletions
@@ -1,3 +1,4 @@
+#include <ATen/native/BinaryOps.h>
 #include <ATen/native/ForeachUtils.h>
 #include <ATen/ops/_foreach_add_native.h>
 #include <ATen/ops/_foreach_addcdiv_native.h>
@@ -7,6 +8,7 @@
 #include <ATen/ops/_foreach_div_native.h>
 #include <ATen/ops/_foreach_mul_native.h>
 #include <ATen/ops/_foreach_pow_native.h>
+#include <ATen/ops/_foreach_sub_native.h>
 
 #include <ATen/native/xpu/sycl/ForeachBinaryOpScalarListKernels.h>
 #include <ATen/native/xpu/sycl/ForeachPointwiseOpScalarListKernels.h>
@@ -40,6 +42,39 @@ namespace native {
     return xpu::FOREACH_BINARY_SCALARLIST_KERNEL_NAME(NAME)(tensors, scalars); \
   }
 
+// This does not use FOREACH_BINARY_OP_SCALARLIST because
+// In the case of subtraction, we dont allow scalar to be boolean following the
+// torch.sub logic
+void foreach_tensor_sub_scalarlist_kernel_xpu_(
+    TensorList tensors,
+    at::ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors, scalars);
+  for (const auto i : c10::irange(tensors.size())) {
+    sub_check(tensors[i], scalars[i]);
+  }
+
+  if (!can_use_fast_route({tensors}, scalars, false)) {
+    return foreach_tensor_sub_scalarlist_kernel_slow_(tensors, scalars);
+  }
+
+  xpu::FOREACH_BINARY_SCALARLIST_INPLACE_KERNEL_NAME(sub)(tensors, scalars);
+}
+
+std::vector<Tensor> foreach_tensor_sub_scalarlist_kernel_xpu(
+    TensorList tensors,
+    at::ArrayRef<Scalar> scalars) {
+  check_foreach_api_restrictions(tensors, scalars);
+  for (const auto i : c10::irange(tensors.size())) {
+    sub_check(tensors[i], scalars[i]);
+  }
+
+  if (!can_use_fast_route({tensors}, scalars, false)) {
+    return foreach_tensor_sub_scalarlist_kernel_slow(tensors, scalars);
+  }
+
+  return xpu::FOREACH_BINARY_SCALARLIST_KERNEL_NAME(sub)(tensors, scalars);
+}
+
 FOREACH_BINARY_OP_SCALARLIST(add, /*div_op*/ false);
 FOREACH_BINARY_OP_SCALARLIST(mul, /*div_op*/ false);
 FOREACH_BINARY_OP_SCALARLIST(div, /*div_op*/ true);
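Whichever path a call takes above (the fused fast route, or the per-tensor slow fallback when `can_use_fast_route` says no), the result is meant to match subtracting each scalar from its tensor one by one. A quick self-check sketch (my illustration, not part of the commit), using the ScalarList variant:

```cpp
#include <ATen/ATen.h>

#include <cassert>
#include <vector>

int main() {
  std::vector<at::Tensor> ts{at::rand({3}), at::rand({2, 2})};
  std::vector<at::Scalar> ss{0.5, 2.0};

  // One foreach call: fast route when the list qualifies, slow path otherwise.
  auto fused = at::_foreach_sub(ts, ss);

  // Either way it should agree with subtracting pair by pair.
  for (size_t i = 0; i < ts.size(); ++i) {
    assert(at::allclose(fused[i], at::sub(ts[i], ss[i])));
  }
  return 0;
}
```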

src/ATen/native/xpu/UpSampleTrilinear3d.cpp

Lines changed: 2 additions & 1 deletion
@@ -28,7 +28,8 @@ TORCH_IMPL_FUNC(upsample_trilinear3d_backward_out_xpu)
     std::optional<double> scales_h,
     std::optional<double> scales_w,
     const Tensor& grad_input) {
-  globalContext().alertNotDeterministic("upsample_trilinear3d_backward_out_xpu");
+  globalContext().alertNotDeterministic(
+      "upsample_trilinear3d_backward_out_xpu");
   xpu::upsample_trilinear3d_backward_out_kernel(
       grad_input,
       grad_output,

src/ATen/native/xpu/sycl/ChebyshevPolynomialTKernel.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 #include <ATen/ATen.h>
 #include <ATen/native/Math.h>
 #include <ATen/native/TensorIterator.h>
-#include <ATen/native/xpu/sycl/Loops.h>
 #include <ATen/native/xpu/sycl/ChebyshevPolynomialKernels.h>
+#include <ATen/native/xpu/sycl/Loops.h>
 
 namespace at::native::xpu {

src/ATen/native/xpu/sycl/ChebyshevPolynomialUKernel.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 #include <ATen/ATen.h>
 #include <ATen/native/Math.h>
 #include <ATen/native/TensorIterator.h>
-#include <ATen/native/xpu/sycl/Loops.h>
 #include <ATen/native/xpu/sycl/ChebyshevPolynomialKernels.h>
+#include <ATen/native/xpu/sycl/Loops.h>
 
 namespace at::native::xpu {

src/ATen/native/xpu/sycl/ChebyshevPolynomialVKernel.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 #include <ATen/ATen.h>
 #include <ATen/native/Math.h>
 #include <ATen/native/TensorIterator.h>
-#include <ATen/native/xpu/sycl/Loops.h>
 #include <ATen/native/xpu/sycl/ChebyshevPolynomialKernels.h>
+#include <ATen/native/xpu/sycl/Loops.h>
 
 namespace at::native::xpu {