Skip to content

Commit 1b09f44

Browse files
committed
make OneDNN pooling call inference path
1 parent 11b1d7c commit 1b09f44

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

33
#include "DevOPs.h"
4+
#include "dbl/Common.h"
5+
#include "dil/dil.hpp"
46
#include "torch_ipex/csrc/aten_ipex_bridge.h"
57
#include "torch_ipex/csrc/utils.h"
68
#include <ATen/Tensor.h>
@@ -150,9 +152,11 @@ class NewMaxPool2dOp : public torch::autograd::Function<NewMaxPool2dOp> {
150152
try {
151153
if (torch_ipex::check_auto_dnnl() &&
152154
input.device().type() == c10::DeviceType::DPCPP) {
155+
auto src_dil_type = torch_ipex::cpu::dbl::comm::try_gen_dil_tensor(input).get_data_type();
156+
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8) ? input : input.contiguous();
157+
153158
at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(
154-
input.is_contiguous() ? input : input.contiguous(), kernel_size,
155-
stride, padding, dilation, ceil_mode);
159+
input_temp, kernel_size, stride, padding, dilation, ceil_mode);
156160
return std::tuple<at::Tensor, at::Tensor>(output, output);
157161
}
158162
} catch (std::exception &e) {
@@ -368,10 +372,10 @@ class NewApaptiveAvgPoolingOp
368372
public:
369373
static at::Tensor _forward(at::Tensor input, at::IntArrayRef output_size) {
370374
try {
371-
if (torch_ipex::check_auto_dnnl() &&
372-
input.device().type() == c10::DeviceType::DPCPP) {
373-
return torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(
374-
input.is_contiguous() ? input : input.contiguous(), output_size);
375+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
376+
auto src_dil_type = torch_ipex::cpu::dbl::comm::try_gen_dil_tensor(input).get_data_type();
377+
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8) ? input : input.contiguous();
378+
return torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input_temp, output_size);
375379
}
376380
} catch (std::exception &e) {
377381
#if defined(_DEBUG)

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,11 @@ at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input
264264
return AtenIpexCPUDev::dil_deconvolution(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.defined() && !bias.is_contiguous() ? bias.contiguous() : bias, padding, output_padding, stride, dilation, groups);
265265
} else {
266266
// for int8 path, input always acbd format which is non-contiguous, .contiguous() will reorder to fp32
267-
return AtenIpexCPUDev::dil_convolution(input, weight, bias, stride, padding, dilation, groups);
267+
auto src_dil_type = dbl::comm::try_gen_dil_tensor(input).get_data_type();
268+
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8) ? input : input.contiguous();
269+
auto weight_dil_type = dbl::comm::try_gen_dil_tensor(weight).get_data_type();
270+
auto weight_temp = weight_dil_type == dil::data_type::s8 ? weight : weight.contiguous();
271+
return AtenIpexCPUDev::dil_convolution(input_temp, weight_temp, bias, stride, padding, dilation, groups);
268272
}
269273
}
270274
}
@@ -788,7 +792,6 @@ at::Tensor AtenIpexCPUDev::dil_linear(
788792
b = dbl::comm::try_gen_dil_tensor(bias);
789793
}
790794

791-
auto output_scale = dbl::comm::get_int8_scale(/* uint8_used=false */);
792795
dil::tensor y = dbl::linear::linear_impl(x, w, b, output_scale);
793796

794797
auto input_size = self.sizes();

torch_ipex/csrc/cpu/dbl/Pool.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ at::Tensor _dil_pooling(
125125
}
126126

127127
dil::tensor y;
128+
dil::prop_kind aprop_kind = dil::prop_kind::forward;
129+
auto src_type = x.get_data_type();
130+
if (dil::data_type::s8 == src_type || dil::data_type::u8 == src_type) {
131+
aprop_kind = dil::prop_kind::forward_inference;
132+
}
128133
dil::pooling_forward::compute(
129134
x,
130135
{output_sizes.cbegin(), output_sizes.cend()},
@@ -134,7 +139,7 @@ at::Tensor _dil_pooling(
134139
{padding_vec_l.cbegin(), padding_vec_l.cend()},
135140
{padding_vec_r.cbegin(), padding_vec_r.cend()},
136141
algo,
137-
dil::prop_kind::forward);
142+
aprop_kind);
138143

139144
return gen_aten_tensor_by(std::move(y));
140145
}

0 commit comments

Comments
 (0)