 #include <ATen/Tensor.h>
 #include <torch/script.h>
 #include <c10/util/Optional.h>
+#include "torch_ipex/csrc/aten_ipex_bridge.h"
 #include "torch_ipex/csrc/utils.h"
 #include "DevOPs.h"
 
@@ -68,17 +69,29 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
     ctx->saved_data["dilation"] = dilation;
     ctx->saved_data["ceil_mode"] = ceil_mode;
 
-    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
-      at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(input.is_contiguous() ? input : input.contiguous(), kernel_size, stride,
-          padding, dilation, ceil_mode);
-      ctx->save_for_backward({input, output});
-      return output;
+    try {
+      if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+        at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(input.is_contiguous() ? input : input.contiguous(), kernel_size, stride,
+            padding, dilation, ceil_mode);
+        ctx->save_for_backward({input, output});
+        return output;
+      }
+    } catch (std::exception& e) {
+#if defined(_DEBUG)
+      TORCH_WARN(e.what());
+#endif
+    }
+    at::Tensor output, indices;
+    if (input.device().type() == c10::DeviceType::DPCPP) {
+      auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
+      auto&& _ipex_result = at::max_pool2d_with_indices(_ipex_input, kernel_size, stride, padding, dilation, ceil_mode);
+      static_cast<void>(_ipex_result);
+      std::tie(output, indices) = std::tuple<at::Tensor,at::Tensor>(torch_ipex::bridge::shallowUpgradeToDPCPPTensor(std::get<0>(_ipex_result)), torch_ipex::bridge::shallowUpgradeToDPCPPTensor(std::get<1>(_ipex_result)));
     } else {
-      at::Tensor output, indices;
       std::tie(output, indices) = at::max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode);
-      ctx->save_for_backward({input, indices});
-      return output;
     }
+    ctx->save_for_backward({input, indices});
+    return output;
   }
 
   static torch::autograd::tensor_list backward(
@@ -97,9 +110,26 @@ class NewMaxPoolingOp : public torch::autograd::Function<NewMaxPoolingOp> {
     std::vector<int64_t> dilation = ctx->saved_data["dilation"].toIntVector();
     bool ceil_mode = ctx->saved_data["ceil_mode"].toBool();
 
-    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
-      grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward(
-          grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), indices.is_contiguous() ? indices : indices.contiguous(), input.is_contiguous() ? input : input.contiguous(), kernel_size, stride, padding, dilation, ceil_mode);
+
+    try {
+      if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+        grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling_backward(
+            grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), indices.is_contiguous() ? indices : indices.contiguous(), input.is_contiguous() ? input : input.contiguous(), kernel_size, stride, padding, dilation, ceil_mode);
+        return {grad_input, at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor(), at::Tensor()};
+      }
+    } catch (std::exception& e) {
+#if defined(_DEBUG)
+      TORCH_WARN(e.what());
+#endif
+    }
+    if (input.device().type() == c10::DeviceType::DPCPP) {
+      auto&& _ipex_grad_output = torch_ipex::bridge::shallowFallbackToCPUTensor(grad_output);
+      auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
+      auto&& _ipex_indices = torch_ipex::bridge::shallowFallbackToCPUTensor(indices);
+      auto&& _ipex_grad_input = at::max_pool2d_with_indices_backward(_ipex_grad_output, _ipex_input, kernel_size,
+          stride, padding, dilation, ceil_mode, _ipex_indices);
+      static_cast<void>(_ipex_grad_input);
+      grad_input = torch_ipex::bridge::shallowUpgradeToDPCPPTensor(_ipex_grad_input);
     } else {
       grad_input = at::max_pool2d_with_indices_backward(grad_output, input, kernel_size,
           stride, padding, dilation, ceil_mode, indices);
@@ -116,13 +146,23 @@ class NewApaptiveAvgPoolingOp : public torch::autograd::Function<NewApaptiveAvgP
       at::IntArrayRef output_size) {
     ctx->save_for_backward({input});
 
-    at::Tensor output;
-    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
-      output = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input.is_contiguous() ? input : input.contiguous(), output_size);
+    try {
+      if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+        return torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input.is_contiguous() ? input : input.contiguous(), output_size);
+      }
+    } catch (std::exception& e) {
+#if defined(_DEBUG)
+      TORCH_WARN(e.what());
+#endif
+    }
+    if (input.device().type() == c10::DeviceType::DPCPP) {
+      auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
+      auto&& _ipex_result = at::_adaptive_avg_pool2d(_ipex_input, output_size);
+      static_cast<void>(_ipex_result); // Avoid warnings in case not used
+      return torch_ipex::bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
     } else {
-      output = at::_adaptive_avg_pool2d(input, output_size);
+      return at::_adaptive_avg_pool2d(input, output_size);
     }
-    return output;
   }
 
   static torch::autograd::tensor_list backward(
@@ -134,8 +174,22 @@ class NewApaptiveAvgPoolingOp : public torch::autograd::Function<NewApaptiveAvgP
     at::Tensor grad_output = grad_outputs[0];
     at::Tensor grad_input;
 
-    if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
-      grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d_backward(grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), input.is_contiguous() ? input : input.contiguous());
+    try {
+      if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
+        grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d_backward(grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), input.is_contiguous() ? input : input.contiguous());
+        return {grad_input, at::Tensor()};
+      }
+    } catch (std::exception& e) {
+#if defined(_DEBUG)
+      TORCH_WARN(e.what());
+#endif
+    }
+    if (input.device().type() == c10::DeviceType::DPCPP) {
+      auto&& _ipex_grad_output = torch_ipex::bridge::shallowFallbackToCPUTensor(grad_output);
+      auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
+      auto&& _ipex_result = at::_adaptive_avg_pool2d_backward(_ipex_grad_output, _ipex_input);
+      static_cast<void>(_ipex_result); // Avoid warnings in case not used
+      grad_input = torch_ipex::bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
     } else {
       grad_input = at::_adaptive_avg_pool2d_backward(grad_output, input);
     }
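
Note: every hunk in this patch applies the same structure: try the oneDNN path when check_auto_dnnl() is enabled and the tensor is on the DPCPP device, and on any exception warn (debug builds only) and fall back to the stock ATen CPU kernel via the shallowFallbackToCPUTensor / shallowUpgradeToDPCPPTensor bridge. A minimal sketch of that control flow is below; the helper name try_dnnl_then_fallback and its signature are illustrative only and are not part of this patch or the IPEX codebase.

    #include <exception>
    #include <ATen/Tensor.h>
    #include <c10/util/Exception.h>  // TORCH_WARN

    // Illustrative helper: run the oneDNN callable, and if it throws, fall back
    // to the CPU callable (which would wrap the shallow bridge round-trip).
    template <typename DnnlFn, typename CpuFn>
    at::Tensor try_dnnl_then_fallback(DnnlFn&& dnnl_path, CpuFn&& cpu_path) {
      try {
        return dnnl_path();        // fast path: dil_* kernel on the DPCPP tensor
      } catch (std::exception& e) {
    #if defined(_DEBUG)
        TORCH_WARN(e.what());      // surface the failure only in debug builds
    #endif
      }
      return cpu_path();           // slow path: shallow fallback to the ATen CPU op
    }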