7575 'aten::_unsafe_view(Tensor self, int[] size) -> Tensor' ,
7676 'aten::native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor)' ,
7777 'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)' ,
78+ # 'aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)',
7879 'aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)'
7980]
8081
8182_FN_IPEX_FUNCS_WITH_SIMPLE_ATEN_SIG = [
8283 'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor' ,
84+ # 'aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)',
8385 'aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)'
86+
8487]
8588
8689_SHALLOW_FALLBACK_TO_CPU_TENSOR_LIST = 'shallowFallbackToCPUTensorList'
@@ -330,7 +333,10 @@ def is_out_func(fname):
330333 if fname .endswith ('_' ):
331334 assert len (dnnl_tensor_param_vars ) > 0
332335 code += ' if (dbl::chk::dnnl_inplace_support_the_tensors(dnnl_input_tensors)) {\n '
333- code += ' return AtenIpexCPUDev::dil_{}({});\n ' .format (fname , ', ' .join (list (param_vars )))
336+ if self .is_ipex_func (aten_func_sig_str ):
337+ code += self .gen_ipex_func_code (fname , param_vars )
338+ else :
339+ code += ' return AtenIpexCPUDev::dil_{}({});\n ' .format (fname , ', ' .join (list (param_vars )))
334340 code += ' }\n ' # Check support tensors
335341 else :
336342 param_seq_str_vec = []
@@ -339,12 +345,7 @@ def is_out_func(fname):
339345 param_seq_str_vec .append (param_seq_str )
340346
341347 if self .is_ipex_func (aten_func_sig_str ):
342- code += ' auto _result = AtenIpexCPUDev::dil_{}({});\n ' .format (fname , ', ' .join (param_seq_str_vec ))
343- code += ' if (is_ipex_func_success()) {\n '
344- code += ' return _result;\n '
345- code += ' } else {\n '
346- code += ' reset_ipex_func_status();\n '
347- code += ' }\n '
348+ code += self .gen_ipex_func_code (fname , param_seq_str_vec )
348349 else :
349350 code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors)) {\n '
350351 code += ' return AtenIpexCPUDev::dil_{}({});\n ' .format (fname , ', ' .join (param_seq_str_vec ))
@@ -359,6 +360,16 @@ def is_out_func(fname):
359360
360361 return code
361362
363+ def gen_ipex_func_code (self , fname , param_vars ):
364+ code = ''
365+ code += ' auto _result = AtenIpexCPUDev::dil_{}({});\n ' .format (fname , ', ' .join (param_vars ))
366+ code += ' if (is_ipex_func_success()) {\n '
367+ code += ' return _result;\n '
368+ code += ' } else {\n '
369+ code += ' reset_ipex_func_status();\n '
370+ code += ' }\n '
371+ return code
372+
362373 def gen_fallback_prepare_code (self , cpp_sig ):
363374 code = ''
364375 op_check_code = ''
0 commit comments