
Commit 8408ee0

Merge branch 'master' into fix_pack_padded_sequence
Conflicts:
	scripts/cpu/gen-dense-cpu-ops.py
	tests/cpu/test_bf16_lazy_reorder.py
	torch_ipex/csrc/cpu/DevOPs.cpp
	torch_ipex/csrc/cpu/DevOPs.h
2 parents ad240de + 1ee3050 commit 8408ee0

6 files changed: +929 -8 lines changed


scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 18 additions & 7 deletions
@@ -75,12 +75,15 @@
     'aten::_unsafe_view(Tensor self, int[] size) -> Tensor',
     'aten::native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor)',
     'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)',
+    # 'aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)',
     'aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)'
 ]
 
 _FN_IPEX_FUNCS_WITH_SIMPLE_ATEN_SIG = [
     'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor',
+    # 'aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)',
     'aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)'
+
 ]
 
 _SHALLOW_FALLBACK_TO_CPU_TENSOR_LIST = 'shallowFallbackToCPUTensorList'
@@ -330,7 +333,10 @@ def is_out_func(fname):
         if fname.endswith('_'):
             assert len(dnnl_tensor_param_vars) > 0
             code += ' if (dbl::chk::dnnl_inplace_support_the_tensors(dnnl_input_tensors)) {\n'
-            code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(list(param_vars)))
+            if self.is_ipex_func(aten_func_sig_str):
+                code += self.gen_ipex_func_code(fname, param_vars)
+            else:
+                code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(list(param_vars)))
             code += ' }\n'  # Check support tensors
         else:
             param_seq_str_vec = []
@@ -339,12 +345,7 @@ def is_out_func(fname):
                 param_seq_str_vec.append(param_seq_str)
 
             if self.is_ipex_func(aten_func_sig_str):
-                code += ' auto _result = AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
-                code += ' if (is_ipex_func_success()) {\n'
-                code += ' return _result;\n'
-                code += ' } else {\n'
-                code += ' reset_ipex_func_status();\n'
-                code += ' }\n'
+                code += self.gen_ipex_func_code(fname, param_seq_str_vec)
             else:
                 code += ' if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors)) {\n'
                 code += ' return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
@@ -359,6 +360,16 @@ def is_out_func(fname):
 
         return code
 
+    def gen_ipex_func_code(self, fname, param_vars):
+        code = ''
+        code += ' auto _result = AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_vars))
+        code += ' if (is_ipex_func_success()) {\n'
+        code += ' return _result;\n'
+        code += ' } else {\n'
+        code += ' reset_ipex_func_status();\n'
+        code += ' }\n'
+        return code
+
     def gen_fallback_prepare_code(self, cpp_sig):
         code = ''
         op_check_code = ''
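As a quick illustration of what this refactor emits, here is a minimal stand-alone Python sketch that reuses the string template from gen_ipex_func_code and prints the guard it would generate for the dil_copy_ entry point added in DevOPs.cpp. The op name and parameter list passed in are illustrative only (aten::copy_ is still commented out in the registration lists above), and the indentation of the emitted C++ is approximate:

# Minimal sketch: same template logic as gen_ipex_func_code, run standalone
# for an illustrative op/parameter list rather than the generator's own tables.
def gen_ipex_func_code(fname, param_vars):
    code = ''
    code += '  auto _result = AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_vars))
    code += '  if (is_ipex_func_success()) {\n'
    code += '    return _result;\n'
    code += '  } else {\n'
    code += '    reset_ipex_func_status();\n'
    code += '  }\n'
    return code

if __name__ == '__main__':
    # Prints the C++ guard that would wrap a call like
    # AtenIpexCPUDev::dil_copy_(self, src, non_blocking).
    print(gen_ipex_func_code('copy_', ['self', 'src', 'non_blocking']))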

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 26 additions & 0 deletions
@@ -2063,5 +2063,31 @@ std::tuple<at::Tensor,at::Tensor> AtenIpexCPUDev::dil__pack_padded_sequence(cons
     std::get<1>(_ipex_result));
 }
 
+at::Tensor& AtenIpexCPUDev::dil_copy_(
+    at::Tensor & self,
+    const at::Tensor & src,
+    bool non_blocking) {
+  DEBUG("AtenIpexCPUDev::dil_copy_\n");
+  torch_ipex::reset_ipex_func_status();
+
+  IPEX_CHECK(
+    self.device().type() == c10::DeviceType::DPCPP &&
+    src.device().type() == c10::DeviceType::DPCPP,
+    "IPEX copy only works on DPCPP tensors");
+  if (ShadeDataContext::isDilTensor(src) && ShadeDataContext::isTensorMixPrecision(src)) {
+    IPEX_CHECK(check_tensor_own_whole_storage(self), "IPEX copy only works while the self tensor owns its whole storage");
+    auto dil_src = dbl::comm::try_gen_dil_tensor(src);
+    IPEX_CHECK(dil_src.get_data_type() == dil::data_type::bf16);
+    auto new_buffer_desc = dil_src.get_desc();
+    dil::tensor dil_buffer{new_buffer_desc};
+    dil_src.reorder_to(dil_buffer);
+    dbl::comm::equip_dil_buffer(self, dil_buffer);
+    return self;
+  }
+  // TODO: we need to add more LP handling here
+  torch_ipex::set_ipex_func_status(torch_ipex::IPEXFuncStatus::IPEX_FALLBACK);
+  return self;
+}
+
 } // namespace cpu
 } // namespace torch_ipex
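The fallback contract between this op and the generator change above is the key point: dil_copy_ resets the func status on entry, handles only the case where src is a bf16 mix-precision DIL tensor, and flags IPEX_FALLBACK otherwise, so the guard emitted by gen_ipex_func_code returns the IPEX result only when is_ipex_func_success() holds and otherwise resets the status and falls through to the stock CPU path. A toy Python model of that handshake (illustrative names only, not the real torch_ipex API):

# Toy model of the IPEX func-status handshake between the generated wrapper
# and dil_copy_; the names here are illustrative, not real torch_ipex symbols.
_SUCCESS, _FALLBACK = 'success', 'fallback'
_status = _SUCCESS

def reset_ipex_func_status():
    global _status
    _status = _SUCCESS

def set_ipex_func_status(status):
    global _status
    _status = status

def is_ipex_func_success():
    return _status == _SUCCESS

def dil_copy_(src_is_mix_precision_dil):
    # Mirrors dil_copy_: handle the bf16 DIL source, otherwise flag a fallback.
    reset_ipex_func_status()
    if src_is_mix_precision_dil:
        return 'copied via dil reorder'
    set_ipex_func_status(_FALLBACK)
    return None

def generated_copy_wrapper(src_is_mix_precision_dil):
    # Mirrors the guard emitted by gen_ipex_func_code.
    _result = dil_copy_(src_is_mix_precision_dil)
    if is_ipex_func_success():
        return _result
    reset_ipex_func_status()
    return 'copied via the stock CPU copy_'

print(generated_copy_wrapper(True))   # copied via dil reorder
print(generated_copy_wrapper(False))  # copied via the stock CPU copy_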

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ class AtenIpexCPUDev {
   static at::Tensor dil__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
   static at::Tensor dil_shuffle(const at::Tensor & self, at::IntArrayRef view_shape, int64_t dim0, int64_t dim1);
   static std::tuple<at::Tensor,at::Tensor> dil__pack_padded_sequence(const at::Tensor & input, const at::Tensor & lengths, bool batch_first);
+  static at::Tensor& dil_copy_(at::Tensor & self, const at::Tensor & src, bool non_blocking);
+
 };
 
 } // namespace cpu

torch_ipex/csrc/cpu/bf16/Bridge.cpp

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ namespace bf16 {
 
 at::Tensor gen_consistent_tensor(const at::Tensor & self) {
   // Reorder dil buffer to public because aten tensor does not support blocked format
+  if (!ShadeDataContext::isDilTensor(self)) {
+    return bridge::shallowFallbackToCPUTensor(self);
+  }
   dbl::comm::reorder_to_public(self, /*keep data type*/true);
 
   dil::tensor& self_dil_storage = ShadeDataContext::getDilStorage(self);

torch_ipex/csrc/cpu/bf16/DevOPs.cpp

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ namespace bf16 {
 
 at::Tensor index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index) {
   auto&& _tensor = bf16::gen_consistent_tensor(self);
-  auto&& _ipex_index = bridge::shallowFallbackToCPUTensor(index);
+  auto&& _ipex_index = bf16::gen_consistent_tensor(index);
   auto&& _ipex_result = at::index_select(_tensor, dim, _ipex_index);
   return bf16::gen_mix_prec_tensor(_ipex_result);
 }
