Skip to content

Commit 6a8c6ae

Browse files
committed
fix typo
1 parent 0d6933c commit 6a8c6ae

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ class NewMaxPool2dOp : public torch::autograd::Function<NewMaxPool2dOp> {
153153
if (torch_ipex::check_auto_dnnl() &&
154154
input.device().type() == c10::DeviceType::DPCPP) {
155155
auto src_dil_type = torch_ipex::cpu::dbl::comm::try_gen_dil_tensor(input).get_data_type();
156-
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8) ? input : input.contiguous();
156+
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8
157+
|| input.is_contiguous()) ? input : input.contiguous();
157158

158159
at::Tensor output = torch_ipex::cpu::AtenIpexCPUDev::dil_max_pooling(
159160
input_temp, kernel_size, stride, padding, dilation, ceil_mode);
@@ -374,7 +375,8 @@ class NewApaptiveAvgPoolingOp
374375
try {
375376
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
376377
auto src_dil_type = torch_ipex::cpu::dbl::comm::try_gen_dil_tensor(input).get_data_type();
377-
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8) ? input : input.contiguous();
378+
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8
379+
|| input.is_contiguous()) ? input : input.contiguous();
378380
return torch_ipex::cpu::AtenIpexCPUDev::dil_adaptive_avg_pool2d(input_temp, output_size);
379381
}
380382
} catch (std::exception &e) {

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input
265265
} else {
266266
// for int8 path, input always acbd format which is non-contiguous, .contiguous() will reorder to fp32
267267
auto src_dil_type = dbl::comm::try_gen_dil_tensor(input).get_data_type();
268-
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8) ? input : input.contiguous();
268+
auto input_temp = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8 || input.is_contiguous()) ? input : input.contiguous();
269269
auto weight_dil_type = dbl::comm::try_gen_dil_tensor(weight).get_data_type();
270-
auto weight_temp = weight_dil_type == dil::data_type::s8 ? weight : weight.contiguous();
270+
auto weight_temp = (weight_dil_type == dil::data_type::s8 || weight.is_contiguous()) ? weight : weight.contiguous();
271271
return AtenIpexCPUDev::dil_convolution(input_temp, weight_temp, bias, stride, padding, dilation, groups);
272272
}
273273
}
@@ -962,7 +962,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
962962
input_scales.push_back(scales[0]);
963963
output_scales.push_back(scales[1]);
964964
dbl::comm::reorder_to_int8_for_mix_prec(input, input_scales);
965-
//std::cout<<"print scale "<<scales[0]<<" "<<input.abs().max()<<std::endl;
966965
} else {
967966
dbl::comm::reorder_to_dtype(input, at::kFloat);
968967
}

0 commit comments

Comments
 (0)