|
1 | 1 | #pragma once |
2 | 2 |
|
3 | 3 | #include "DevOPs.h" |
| 4 | +#include "dbl/Common.h" |
| 5 | +#include "dil/dil.hpp" |
4 | 6 | #include "torch_ipex/csrc/aten_ipex_bridge.h" |
5 | 7 | #include "torch_ipex/csrc/utils.h" |
6 | 8 | #include <ATen/Tensor.h> |
@@ -150,9 +152,11 @@ class NewMaxPool2dOp : public torch::autograd::Function<NewMaxPool2dOp> { |
150 | 152 | try { |
151 | 153 | if (torch_ipex::check_auto_dnnl() && |
152 | 154 | input.device().type() == c10::DeviceType::DPCPP) { |
| 155 | + auto src_dil_type = torch_ipex::cpu::dbl::comm::try_gen_dil_tensor(input).get_data_type(); |
| 156 | + auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8) ? input : input.contiguous(); |
| 157 | + |
153 | 158 | at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling( |
154 | | - input.is_contiguous() ? input : input.contiguous(), kernel_size, |
155 | | - stride, padding, dilation, ceil_mode); |
| 159 | + input_temp, kernel_size, stride, padding, dilation, ceil_mode); |
156 | 160 | return std::tuple<at::Tensor, at::Tensor>(output, output); |
157 | 161 | } |
158 | 162 | } catch (std::exception &e) { |
@@ -368,10 +372,10 @@ class NewApaptiveAvgPoolingOp |
368 | 372 | public: |
369 | 373 | static at::Tensor _forward(at::Tensor input, at::IntArrayRef output_size) { |
370 | 374 | try { |
371 | | - if (torch_ipex::check_auto_dnnl() && |
372 | | - input.device().type() == c10::DeviceType::DPCPP) { |
373 | | - return torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d( |
374 | | - input.is_contiguous() ? input : input.contiguous(), output_size); |
| 375 | + if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) { |
| 376 | + auto src_dil_type = torch_ipex::cpu::dbl::comm::try_gen_dil_tensor(input).get_data_type(); |
| 377 | + auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8) ? input : input.contiguous(); |
| 378 | + return torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input_temp, output_size); |
375 | 379 | } |
376 | 380 | } catch (std::exception &e) { |
377 | 381 | #if defined(_DEBUG) |
|
0 commit comments