7474 'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor' ,
7575 'aten::_unsafe_view(Tensor self, int[] size) -> Tensor' ,
7676 'aten::native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor)' ,
77- 'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)'
77+ 'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)' ,
78+ 'aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)'
7879]
7980
8081_FN_IPEX_FUNCS_WITH_SIMPLE_ATEN_SIG = [
8182 'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor' ,
83+ 'aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)'
8284]
8385
8486_SHALLOW_FALLBACK_TO_CPU_TENSOR_LIST = 'shallowFallbackToCPUTensorList'
@@ -312,10 +314,17 @@ def is_out_func(fname):
312314 code += ' try {\n '
313315
314316 code += ' if (check_auto_dnnl()) {\n '
315- code += ' std::vector<at::Tensor> dnnl_input_tensors;\n '
316- if len (dnnl_tensor_param_vars ) > 0 :
317- for dnnl_tensor_param_var in dnnl_tensor_param_vars :
318- code += ' dnnl_input_tensors.push_back({});\n ' .format (dnnl_tensor_param_var )
317+
318+ if not self .is_ipex_func (aten_func_sig_str ):
319+ # There are two different kind of DevOPs in IPEX
320+ # 1. DNNL Operator
321+ # 2. CPU BF16/INT8 Operator in Vanilla PyTorch. IPEX itegrates this kind of operators in IPEX for
322+ # mixture precision.
323+ # For the type 2, IPEX does not need to check if DNNL supports these tensors.
324+ code += ' std::vector<at::Tensor> dnnl_input_tensors;\n '
325+ if len (dnnl_tensor_param_vars ) > 0 :
326+ for dnnl_tensor_param_var in dnnl_tensor_param_vars :
327+ code += ' dnnl_input_tensors.push_back({});\n ' .format (dnnl_tensor_param_var )
319328
320329 fname = cpp_sig .def_name
321330 if fname .endswith ('_' ):
@@ -328,19 +337,18 @@ def is_out_func(fname):
328337 for param_var in param_vars :
329338 param_seq_str = param_var
330339 param_seq_str_vec .append (param_seq_str )
331- code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors)) {\n '
332340
333341 if self .is_ipex_func (aten_func_sig_str ):
334- code += ' auto _result = AtenIpexCPUDev::dil_{}({});\n ' .format (fname , ', ' .join (param_seq_str_vec ))
335- code += ' if (is_ipex_func_success()) {\n '
336- code += ' return _result;\n '
337- code += ' } else {\n '
338- code += ' reset_ipex_func_status();\n '
339- code += ' }\n '
342+ code += ' auto _result = AtenIpexCPUDev::dil_{}({});\n ' .format (fname , ', ' .join (param_seq_str_vec ))
343+ code += ' if (is_ipex_func_success()) {\n '
344+ code += ' return _result;\n '
345+ code += ' } else {\n '
346+ code += ' reset_ipex_func_status();\n '
347+ code += ' }\n '
340348 else :
349+ code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors)) {\n '
341350 code += ' return AtenIpexCPUDev::dil_{}({});\n ' .format (fname , ', ' .join (param_seq_str_vec ))
342-
343- code += ' }\n ' # Check support tensors
351+ code += ' }\n ' # Check support tensors
344352 code += ' }\n ' # Check auto dnnl
345353 code += ' } catch (std::exception& e) {\n '
346354 code += '#if defined(_DEBUG)\n '
0 commit comments