Skip to content

Commit 3e475c0

Browse files
yucai-intel and toyxu authored
Add aten::fractional_max_pool2d, fractional_max_pool3d (forward and backward) (#880)
- [x] fractional_max_pool2d - [x] fractional_max_pool2d_backward - [x] fractional_max_pool3d - [x] fractional_max_pool3d_backward --------- Co-authored-by: Yutao Xu <[email protected]>
1 parent 0cd3091 commit 3e475c0

10 files changed

+827
-4
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include <ATen/core/Tensor.h>
2+
#include <ATen/core/op_registration/adaption.h>
3+
#include <ATen/native/cpu/mixed_data_type.h>
4+
#include <ATen/native/xpu/sycl/FractionalMaxPool2dKernels.h>
5+
6+
#include <xpu/ATen/ops/fractional_max_pool2d_backward_native.h>
7+
#include <xpu/ATen/ops/fractional_max_pool2d_native.h>
8+
9+
namespace at::native {
10+
11+
// XPU implementation body for fractional_max_pool2d (out variant), declared
// through the TORCH_IMPL_FUNC macro.
//
// This is a thin wrapper: all six arguments are forwarded unchanged, and in
// the same order, to xpu::fractional_max_pool2d_kernel. No validation,
// allocation, or dispatch happens at this level.
//
// NOTE(review): `output` and `indices` arrive as `const Tensor&`; presumably
// the kernel writes the pooled values and arg-max indices into them --
// confirm against FractionalMaxPool2dKernels.h.
TORCH_IMPL_FUNC(fractional_max_pool2d_out_xpu)
12+
(const Tensor& input,
13+
IntArrayRef pool_size,
14+
IntArrayRef output_size,
15+
const Tensor& randomSamples,
16+
const Tensor& output,
17+
const Tensor& indices) {
18+
xpu::fractional_max_pool2d_kernel(
19+
input, pool_size, output_size, randomSamples, output, indices);
20+
}
21+
22+
// XPU implementation body for fractional_max_pool2d_backward, declared
// through the TORCH_IMPL_FUNC macro.
//
// Thin wrapper: forwards all six arguments unchanged, in declaration order,
// to xpu::fractional_max_pool2d_backward_kernel.
//
// NOTE(review): `pool_size` is annotated `/* unused */` in the signature but
// is still passed to the kernel below. Presumably the kernel ignores it --
// confirm, and either drop the annotation or stop forwarding the argument.
// NOTE(review): unlike the 3d backward wrappers in this change, no
// alertNotDeterministic call is made here; verify whether the kernel itself
// raises that alert.
TORCH_IMPL_FUNC(fractional_max_pool2d_backward_xpu)
23+
(const Tensor& gradOutput,
24+
const Tensor& input,
25+
IntArrayRef pool_size /* unused */,
26+
IntArrayRef output_size,
27+
const Tensor& indices,
28+
const Tensor& gradInput) {
29+
xpu::fractional_max_pool2d_backward_kernel(
30+
gradOutput, input, pool_size, output_size, indices, gradInput);
31+
}
32+
33+
} // namespace at::native
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#include <ATen/core/Tensor.h>
2+
#include <ATen/core/op_registration/adaption.h>
3+
#include <ATen/native/cpu/mixed_data_type.h>
4+
#include <ATen/native/xpu/sycl/FractionalMaxPool3dKernels.h>
5+
#include <ATen/ops/empty.h>
6+
7+
#include <xpu/ATen/ops/fractional_max_pool3d_backward_native.h>
8+
#include <xpu/ATen/ops/fractional_max_pool3d_native.h>
9+
10+
namespace at::native {
11+
12+
// XPU implementation body for fractional_max_pool3d (out variant), declared
// through the TORCH_IMPL_FUNC macro.
//
// The pooling window (poolSizeT/H/W), output extent (outputT/H/W) and input
// geometry (numBatch, numPlanes, inputT/H/W) are received here as individual
// int64_t scalars rather than as array arguments. This wrapper performs no
// computation on them: all fifteen arguments are forwarded unchanged, in the
// same order, to xpu::fractional_max_pool3d_kernel.
//
// NOTE(review): `output` and `indices` arrive as `const Tensor&`; presumably
// the kernel fills them in -- confirm against FractionalMaxPool3dKernels.h.
TORCH_IMPL_FUNC(fractional_max_pool3d_out_xpu)
13+
(const Tensor& input,
14+
int64_t poolSizeT,
15+
int64_t poolSizeH,
16+
int64_t poolSizeW,
17+
int64_t outputT,
18+
int64_t outputH,
19+
int64_t outputW,
20+
const Tensor& randomSamples,
21+
int64_t numBatch,
22+
int64_t numPlanes,
23+
int64_t inputT,
24+
int64_t inputH,
25+
int64_t inputW,
26+
const Tensor& output,
27+
const Tensor& indices) {
28+
xpu::fractional_max_pool3d_kernel(
29+
input,
30+
poolSizeT,
31+
poolSizeH,
32+
poolSizeW,
33+
outputT,
34+
outputH,
35+
outputW,
36+
randomSamples,
37+
numBatch,
38+
numPlanes,
39+
inputT,
40+
inputH,
41+
inputW,
42+
output,
43+
indices);
44+
}
45+
46+
// Out-variant backward for fractional_max_pool3d on XPU.
//
// Steps, in order:
//   1. Reports the op as non-deterministic through
//      globalContext().alertNotDeterministic.
//   2. Calls xpu::fractional_max_pool3d_backward_kernel with
//      (grad_input, grad_output, input, output_size, indices).
//   3. Returns the caller-supplied `grad_input` reference.
//
// `pool_size` is part of the signature but is not read or forwarded here.
//
// NOTE(review): note the kernel's argument order puts grad_input first,
// unlike this function's own parameter order.
Tensor& fractional_max_pool3d_backward_out_xpu(
47+
const Tensor& grad_output,
48+
const Tensor& input,
49+
IntArrayRef pool_size,
50+
IntArrayRef output_size,
51+
const Tensor& indices,
52+
Tensor& grad_input) {
53+
globalContext().alertNotDeterministic(
54+
"fractional_max_pool3d_backward_out_xpu");
55+
xpu::fractional_max_pool3d_backward_kernel(
56+
grad_input, grad_output, input, output_size, indices);
57+
return grad_input;
58+
}
59+
60+
// Functional (allocating) backward for fractional_max_pool3d on XPU.
//
// Reports the op as non-deterministic, creates a fresh `grad_input` tensor,
// hands it to xpu::fractional_max_pool3d_backward_kernel together with
// grad_output, input, output_size and indices, and returns it by value.
//
// `pool_size` is part of the signature but is not read or forwarded here.
Tensor fractional_max_pool3d_backward_xpu(
61+
const Tensor& grad_output,
62+
const Tensor& input,
63+
IntArrayRef pool_size,
64+
IntArrayRef output_size,
65+
const Tensor& indices) {
66+
globalContext().alertNotDeterministic("fractional_max_pool3d_backward_xpu");
67+
// Placeholder allocation: shape {0} with the same options as `input`.
// NOTE(review): presumably the kernel resizes grad_input to input's shape
// before writing -- confirm in FractionalMaxPool3dKernels.
Tensor grad_input = at::empty({0}, input.options());
68+
xpu::fractional_max_pool3d_backward_kernel(
69+
grad_input, grad_output, input, output_size, indices);
70+
return grad_input;
71+
}
72+
73+
} // namespace at::native

src/ATen/native/xpu/XPUFallback.template

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
164164
"_fft_c2r",
165165
"_fft_r2c",
166166
"_flash_attention_forward",
167-
"fractional_max_pool2d_backward.grad_input",
168-
"fractional_max_pool2d.output",
169-
"fractional_max_pool3d_backward",
170-
"fractional_max_pool3d.output",
171167
"frexp.Tensor_out",
172168
"_fused_moving_avg_obs_fq_helper",
173169
"geqrf",

0 commit comments

Comments
 (0)