@@ -14,6 +14,7 @@
 #include "dbl/Common.h"
 #include "dbl/Conv.h"
 #include "dbl/Pool.h"
+#include "dbl/DNNLChecker.h"
 #include "ShadeDataContext.h"
 
 #include "dil/dil.hpp"
@@ -173,19 +174,31 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_bac
 
 at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
   DEBUG("AtenIpexCPUDev::convolution_overrideable\n");
-  // NOTE: DO NOT always call contiguous. It may break lazy-reorder. Because contiguous will call reorder instantly.
-  if (check_auto_dnnl()) {
-    return dil_convolution(
-      input.is_contiguous() ? input : input.contiguous(),
-      weight.is_contiguous() ? weight : weight.contiguous(),
-      bias.defined() ? (bias.is_contiguous() ? bias : bias.contiguous()) : bias,
-      stride,
-      padding,
-      dilation,
-      groups);
-  } else {
-    return mkldnn_convolution(input, weight, bias, padding, stride, dilation, groups);
+
+  try {
+    if (check_auto_dnnl()) {
+      std::vector<at::Tensor> dnnl_input_tensors;
+      dnnl_input_tensors.push_back(input);
+      dnnl_input_tensors.push_back(weight);
+      dnnl_input_tensors.push_back(bias);
+      if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors))
+        return AtenIpexCPUDev::dil_convolution(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.is_contiguous() ? bias : bias.contiguous(), stride, padding, dilation, groups);
+    }
+  } catch (std::exception& e) {
+#if defined(_DEBUG)
+    TORCH_WARN(e.what());
+#endif
   }
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bias.layout() == c10::kStrided);
+  auto&& _ipex_input = bridge::shallowFallbackToCPUTensor(input);
+  auto&& _ipex_weight = bridge::shallowFallbackToCPUTensor(weight);
+  auto&& _ipex_bias = bridge::shallowFallbackToCPUTensor(bias);
+  auto&& _ipex_result = at::mkldnn_convolution(_ipex_input, _ipex_weight, _ipex_bias, padding, stride, dilation, groups);
+  static_cast<void>(_ipex_result); // Avoid warnings in case not used
+  return bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
 }
 
 at::Tensor AtenIpexCPUDev::mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
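The new body of `dil_convolution_overrideable` follows a try-fast-path-then-fall-back dispatch pattern: attempt the DNNL kernel when a support check passes, and on any failure (unsupported tensors or an exception from the backend) fall through to a reference implementation. Below is a minimal, self-contained C++ sketch of that pattern; the `Tensor` struct, `dnnl_supports()`, and both kernels are hypothetical stand-ins for illustration, not the IPEX/ATen API.

```cpp
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for at::Tensor; not the real ATen type.
struct Tensor {
  bool dnnl_friendly = true;  // stand-in for layout/dtype support checks
};

// Stand-in for a check like dbl::chk::dnnl_support_the_tensors(): the fast
// path is taken only if every input tensor is representable by the backend.
bool dnnl_supports(const std::vector<Tensor>& tensors) {
  for (const auto& t : tensors) {
    if (!t.dnnl_friendly) return false;
  }
  return true;
}

Tensor dnnl_conv(const Tensor&, const Tensor&, const Tensor&) {
  std::cout << "fast DNNL path\n";
  return {};
}

Tensor reference_conv(const Tensor&, const Tensor&, const Tensor&) {
  std::cout << "fallback reference path\n";
  return {};
}

Tensor convolution(const Tensor& input, const Tensor& weight, const Tensor& bias) {
  // First choice: the DNNL kernel, guarded both by an explicit support check
  // and by a catch-all, so an unexpected backend error degrades to the
  // fallback path instead of aborting the whole op.
  try {
    if (dnnl_supports({input, weight, bias})) {
      return dnnl_conv(input, weight, bias);
    }
  } catch (const std::exception& e) {
    std::cerr << "DNNL path failed: " << e.what() << '\n';
  }
  // Fallback: a reference implementation on plain strided tensors.
  return reference_conv(input, weight, bias);
}

int main() {
  Tensor input, weight, bias;
  convolution(input, weight, bias);  // takes the DNNL path
  weight.dnnl_friendly = false;
  convolution(input, weight, bias);  // support check fails -> fallback
}
```

Pairing the explicit support check with a catch-all mirrors the design choice in the diff: the support check handles the known-unsupported cases cheaply, while the try/catch turns any remaining backend failure into a silent (or debug-logged) fallback rather than a hard error.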