
Commit e68a649

haozhe.zhu committed
enable bf16 copy_
1 parent 264ceaa · commit e68a649

File tree

6 files changed: 75 additions & 9 deletions


scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 18 additions & 8 deletions
@@ -74,11 +74,13 @@
     'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor',
     'aten::_unsafe_view(Tensor self, int[] size) -> Tensor',
     'aten::native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor)',
-    'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)'
+    'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)',
+    'aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)',
 ]
 
 _FN_IPEX_FUNCS_WITH_SIMPLE_ATEN_SIG = [
     'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor',
+    'aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)',
 ]
 
 _SHALLOW_FALLBACK_TO_CPU_TENSOR_LIST = 'shallowFallbackToCPUTensorList'
@@ -321,7 +323,10 @@ def is_out_func(fname):
         if fname.endswith('_'):
             assert len(dnnl_tensor_param_vars) > 0
             code += '    if (dbl::chk::dnnl_inplace_support_the_tensors(dnnl_input_tensors)) {\n'
-            code += '      return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(list(param_vars)))
+            if self.is_ipex_func(aten_func_sig_str):
+                code += self.gen_ipex_func_code(fname, param_vars)
+            else:
+                code += '      return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(list(param_vars)))
             code += '    }\n' # Check support tensors
         else:
             param_seq_str_vec = []
@@ -331,12 +336,7 @@ def is_out_func(fname):
             code += '    if (dbl::chk::dnnl_support_the_tensors(dnnl_input_tensors)) {\n'
 
             if self.is_ipex_func(aten_func_sig_str):
-                code += '      auto _result = AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
-                code += '      if (is_ipex_func_success()) {\n'
-                code += '        return _result;\n'
-                code += '      } else {\n'
-                code += '        reset_ipex_func_status();\n'
-                code += '      }\n'
+                code += self.gen_ipex_func_code(fname, param_seq_str_vec)
             else:
                 code += '      return AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_seq_str_vec))
 
@@ -351,6 +351,16 @@ def is_out_func(fname):
 
         return code
 
+    def gen_ipex_func_code(self, fname, param_vars):
+        code = ''
+        code += '      auto _result = AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_vars))
+        code += '      if (is_ipex_func_success()) {\n'
+        code += '        return _result;\n'
+        code += '      } else {\n'
+        code += '        reset_ipex_func_status();\n'
+        code += '      }\n'
+        return code
+
     def gen_fallback_prepare_code(self, cpp_sig):
         code = ''
         op_check_code = ''
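
What gen_ipex_func_code emits is just the try-then-fall-back block that the deleted inline lines used to build by hand. The snippet below is a standalone reproduction of the helper, copied from the format strings above (the whitespace inside the emitted C++ strings is approximate, since the diff view collapses indentation), so you can see what it generates for the newly registered copy_ overload:

    # Illustrative reproduction of gen_ipex_func_code; indentation inside the
    # emitted C++ strings is approximate.
    def gen_ipex_func_code(fname, param_vars):
        code = ''
        code += '      auto _result = AtenIpexCPUDev::dil_{}({});\n'.format(fname, ', '.join(param_vars))
        code += '      if (is_ipex_func_success()) {\n'
        code += '        return _result;\n'
        code += '      } else {\n'
        code += '        reset_ipex_func_status();\n'
        code += '      }\n'
        return code

    print(gen_ipex_func_code('copy_', ['self', 'src', 'non_blocking']))
    # Prints a block that calls AtenIpexCPUDev::dil_copy_(self, src, non_blocking),
    # returns its result when is_ipex_func_success() reports success, and otherwise
    # calls reset_ipex_func_status() so the wrapper falls through to the CPU path.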

tests/cpu/test_bf16_lazy_reorder.py

Lines changed: 25 additions & 0 deletions
@@ -1956,6 +1956,31 @@ def test_save_and_load(self):
         torch.save(output_dpcpp, 'tensor_dpcpp.pt')
         self.assertEqual(torch.load('tensor.pt'), torch.load('tensor_dpcpp.pt'))
 
+class TestCopy_(TestCase):
+    def test_copy_(self):
+        rand_seed = int(get_rand_seed())
+        print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
+        torch.manual_seed(rand_seed)
+        self_auto_mix = torch.randn(3, 4, 5, dtype=torch.float32, device=device) * 10
+        self_man_mix = (torch.randn(3, 4, 5, device=device) * 10).to(torch.bfloat16)
+        src_auto_mix = torch.randn(3, 4, 5, dtype=torch.float32, device=device) * 10
+        copy_src_auto_mix = copy.deepcopy(src_auto_mix).to(device=device)
+        copy_src_man_mix = copy.deepcopy(src_auto_mix).to(device=device).to(torch.bfloat16)
+
+        with AutoDNNL(True), AutoMixPrecision(False):
+            res_man_bf16 = copy_src_man_mix + copy_src_man_mix
+            self.assertEqual(res_man_bf16.dtype, torch.bfloat16)
+            self_man_mix.copy_(res_man_bf16)
+            self.assertEqual(self_man_mix.dtype, torch.bfloat16)
+
+        with AutoMixPrecision(True):
+            res_auto_mix = copy_src_auto_mix + copy_src_auto_mix
+            self.assertEqual(res_auto_mix.dtype, torch.float)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix))
+            self_auto_mix.copy_(res_auto_mix)
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(self_auto_mix))
+            self.assertEqual(self_auto_mix.dtype, torch.float)
+            self.assertEqual(self_auto_mix, self_man_mix.float())
 
 if __name__ == '__main__':
     test = unittest.main()
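
In short, TestCopy_ exercises both sides of the new path: with AutoMixPrecision off, copy_ between explicitly bf16 tensors keeps the bf16 dtype; with it on, copying a bf16 DIL result into an fp32 tensor leaves the aten dtype at float while the underlying DIL buffer becomes bf16, and the two results match. A minimal sketch of the auto-mix case, assuming the helpers the test module already imports (torch, ipex, AutoDNNL, AutoMixPrecision, device) are in scope:

    # Minimal auto-mix copy_ sketch; AutoDNNL, AutoMixPrecision, device and ipex
    # come from the surrounding test module, as in TestCopy_ above.
    dst = torch.randn(3, 4, 5, dtype=torch.float32, device=device)
    src = torch.randn(3, 4, 5, dtype=torch.float32, device=device)

    with AutoDNNL(True), AutoMixPrecision(True):
        bf16_res = src + src                       # the add runs in bf16 under auto-mix
        dst.copy_(bf16_res)                        # dispatches to AtenIpexCPUDev::dil_copy_
        assert dst.dtype == torch.float            # aten dtype is unchanged...
        assert ipex.core.is_bf16_dil_tensor(dst)   # ...but the DIL buffer is now bf16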

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 26 additions & 0 deletions
@@ -2042,5 +2042,31 @@ at::Tensor AtenIpexCPUDev::dil_shuffle(const at::Tensor & self, at::IntArrayRef
   return dbl::comm::gen_aten_tensor_by(std::move(y));
 }
 
+at::Tensor& AtenIpexCPUDev::dil_copy_(
+    at::Tensor & self,
+    const at::Tensor & src,
+    bool non_blocking) {
+  DEBUG("AtenIpexCPUDev::dil_copy_\n");
+  torch_ipex::reset_ipex_func_status();
+
+  IPEX_CHECK(
+    self.device().type() == c10::DeviceType::DPCPP &&
+    src.device().type() == c10::DeviceType::DPCPP,
+    "IPEX copy only work on DPCPP tensor");
+  if (ShadeDataContext::isDilTensor(src) && ShadeDataContext::isTensorMixPrecision(src)) {
+    IPEX_CHECK(check_tensor_own_whole_storage(self), "IPEX copy only works while self tensor own the whole storage");
+    auto dil_src = dbl::comm::try_gen_dil_tensor(src);
+    IPEX_CHECK(dil_src.get_data_type() == dil::data_type::bf16);
+    auto new_buffer_desc = dil_src.get_desc();
+    dil::tensor dil_buffer{new_buffer_desc};
+    dil_src.reorder_to(dil_buffer);
+    dbl::comm::equip_dil_buffer(self, dil_buffer);
+    return self;
+  }
+  // TODO: We need add more LP here
+  torch_ipex::set_ipex_func_status(torch_ipex::IPEXFuncStatus::IPEX_FALLBACK);
+  return self;
+}
+
 } // namespace cpu
 } // namespace torch_ipex
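
dil_copy_ only claims the case where src already carries a bf16 DIL buffer: it reorders that buffer into a fresh dil::tensor built from the same descriptor and equips self with it, so the data stays in bf16 throughout. Every other combination (the TODO notes that more LP, i.e. low-precision, cases are still to be added) sets IPEX_FALLBACK, and the generated wrapper then re-routes the call to the stock aten copy_. A hedged sketch of that fallback behavior as observed from Python, again assuming the test module's helpers are in scope:

    # Plain fp32 copy: src has no bf16 DIL buffer, so dil_copy_ flags IPEX_FALLBACK
    # and the call lands on the regular aten copy_; values are copied as usual.
    a = torch.randn(3, 4, 5, dtype=torch.float32, device=device)
    b = torch.zeros(3, 4, 5, dtype=torch.float32, device=device)

    with AutoDNNL(True), AutoMixPrecision(False):
        b.copy_(a)
        assert b.dtype == torch.float32
        assert torch.equal(a.to('cpu'), b.to('cpu'))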

torch_ipex/csrc/cpu/DevOPs.h

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ class AtenIpexCPUDev {
   static at::Tensor dil_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index);
   static at::Tensor dil__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
   static at::Tensor dil_shuffle(const at::Tensor & self, at::IntArrayRef view_shape, int64_t dim0, int64_t dim1);
+  static at::Tensor& dil_copy_(at::Tensor & self, const at::Tensor & src, bool non_blocking);
+
 };
 
 } // namespace cpu

torch_ipex/csrc/cpu/bf16/Bridge.cpp

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ namespace bf16 {
 
 at::Tensor gen_consistent_tensor(const at::Tensor & self) {
   // Reorder dil buffer to public because aten tensor does not support blocked format
+  if (!ShadeDataContext::isDilTensor(self)) {
+    return bridge::shallowFallbackToCPUTensor(self);
+  }
   dbl::comm::reorder_to_public(self, /*keep data type*/true);
 
   dil::tensor& self_dil_storage = ShadeDataContext::getDilStorage(self);

torch_ipex/csrc/cpu/bf16/DevOPs.cpp

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ namespace bf16 {
 
 at::Tensor index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index) {
   auto&& _tensor = bf16::gen_consistent_tensor(self);
-  auto&& _ipex_index = bridge::shallowFallbackToCPUTensor(index);
+  auto&& _ipex_index = bf16::gen_consistent_tensor(index);
   auto&& _ipex_result = at::index_select(_tensor, dim, _ipex_index);
   return bf16::gen_mix_prec_tensor(_ipex_result);
 }
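
These last two bf16 changes go together: index_select now feeds its index tensor through bf16::gen_consistent_tensor, and since an index is usually a plain int64 tensor with no DIL buffer behind it, that is presumably why the helper gained the early branch that shallow-falls-back instead of unconditionally asking ShadeDataContext for DIL storage. A small usage sketch, under the same assumption that the test module's helpers are in scope:

    # index_select on a bf16 DIL tensor with a plain int64 index tensor; the index
    # side is what the new non-DIL branch in gen_consistent_tensor handles.
    x = torch.randn(4, 5, dtype=torch.float32, device=device)
    idx = torch.tensor([0, 2], dtype=torch.int64, device=device)

    with AutoDNNL(True), AutoMixPrecision(True):
        y = x + x                               # bf16 DIL tensor under auto-mix
        out = torch.index_select(y, 0, idx)     # routed through bf16::index_select
        assert out.shape == (2, 5)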
