Skip to content

Commit ad240de

Browse files
committed
Fix pack_padded_sequence because PyTorch requires the batch size tensor to be on CPU
1 parent fda3b64 commit ad240de

File tree

7 files changed

+187
-124
lines changed

7 files changed

+187
-124
lines changed

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,13 @@
7474
'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor',
7575
'aten::_unsafe_view(Tensor self, int[] size) -> Tensor',
7676
'aten::native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor)',
77-
'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)'
77+
'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)',
78+
'aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)'
7879
]
7980

8081
_FN_IPEX_FUNCS_WITH_SIMPLE_ATEN_SIG = [
8182
'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor',
83+
'aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)'
8284
]
8385

8486
_SHALLOW_FALLBACK_TO_CPU_TENSOR_LIST = 'shallowFallbackToCPUTensorList'
@@ -312,10 +314,17 @@ def is_out_func(fname):
312314
code += ' try {\n'
313315

314316
code += ' if (check_auto_dnnl()) {\n'
315-
code += ' std::vector<at::Tensor> dnnl_input_tensors;\n'
316-
if len(dnnl_tensor_param_vars) > 0:
317-
for dnnl_tensor_param_var in dnnl_tensor_param_vars:
318-
code += ' dnnl_input_tensors.push_back({});\n'.format(dnnl_tensor_param_var)
317+
318+
if not self.is_ipex_func(aten_func_sig_str):
319+
# There are two different kinds of DevOPs in IPEX
320+
# 1. DNNL Operator
321+
# 2. CPU BF16/INT8 Operator in Vanilla PyTorch. IPEX integrates this kind of operator in IPEX for
322+
# mixture precision.
323+
# For the type 2, IPEX does not need to check if DNNL supports these tensors.
324+
code += ' std::vector<at::Tensor> dnnl_input_tensors;\n'
325+
if len(dnnl_tensor_param_vars) > 0:
326+
for dnnl_tensor_param_var in dnnl_tensor_param_vars:
327+
code += ' dnnl_input_tensors.push_back({});\n'.format(dnnl_tensor_param_var)
319328

320329
fname = cpp_sig.def_name
321330
if fname.endswith('_'):
@@ -328,19 +337,18 @@ def is_out_func(fname):
328337
for param_var in param_vars:
329338
param_seq_str = param_var
330339
param_seq_str_vec.append(param_seq_str)
331-
code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors)) {\n'
332340

333341
if self.is_ipex_func(aten_func_sig_str):
334-
code += ' auto _result = AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
335-
code += ' if (is_ipex_func_success()) {\n'
336-
code += ' return _result;\n'
337-
code += ' } else {\n'
338-
code += ' reset_ipex_func_status();\n'
339-
code += ' }\n'
342+
code += ' auto _result = AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
343+
code += ' if (is_ipex_func_success()) {\n'
344+
code += ' return _result;\n'
345+
code += ' } else {\n'
346+
code += ' reset_ipex_func_status();\n'
347+
code += ' }\n'
340348
else:
349+
code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors)) {\n'
341350
code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
342-
343-
code += ' }\n' # Check support tensors
351+
code += ' }\n' # Check support tensors
344352
code += ' }\n' # Check auto dnnl
345353
code += ' } catch (std::exception& e) {\n'
346354
code += '#if defined(_DEBUG)\n'

0 commit comments

Comments
 (0)