@@ -48,52 +48,28 @@ void reorder_to_bf16_for_mix_prec(const at::Tensor& tensor) {
4848}
4949
5050void reorder_to_dtype (const at::Tensor& tensor, at::ScalarType dst_scalar_type) {
51- dil::tensor::memory::desc dst_desc;
52- if (check_tensor_own_shade_context (tensor) && cpu::ShadeDataContext::isDilOwnTheTensor (tensor)) {
53- // The buffer ownership is DIL
54- cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(tensor.storage ().data_ptr ().get_context ());
55- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (shade_context->dil_tensor .has_value ());
56- dil::tensor&& dil_tensor = std::move (shade_context->dil_tensor .value ());
57- if (get_at_data_type (dil_tensor.get_data_type ()) == dst_scalar_type) {
58- // The data type of DIL tensor is same as the destination data type. DO NOTHING
59- return ;
60- }
61-
62- TORCH_CHECK (check_tensor_own_whole_storage (tensor),
63- " Intel Extension for PyTorch does not support the data is just a part of its storage for auto mix-precision." );
64- dst_desc = dil_tensor.get_desc ().to_type (get_dil_data_type (dst_scalar_type));
65- } else {
66- // The buffer ownership is CPU
67- TORCH_CHECK (check_tensor_own_whole_storage (tensor),
68- " Intel Extension for PyTorch does not support the data is just a part of its storage for auto mix-precision." );
69- dil::tensor temp_dil_tensor = cpu::dbl::comm::dil_tensor_from_dense (tensor);
70- dst_desc = temp_dil_tensor.get_desc ().to_type (get_dil_data_type (dst_scalar_type));
51+ auto src = try_gen_dil_tensor (tensor);
52+ if (get_at_data_type (src.get_data_type ()) == dst_scalar_type) {
53+ // The data type of DIL tensor is same as the dst data type. DO NOTHING
54+ return ;
7155 }
56+ auto dst_desc = src.get_desc ().to_type (get_dil_data_type (dst_scalar_type));
7257 reorder_to_desc (tensor, dst_desc);
7358}
7459
7560void reorder_to_desc (const at::Tensor& tensor, const dil::tensor::desc& expected_desc) {
76- dil::tensor new_type_dil_tensor;
77- new_type_dil_tensor.init (expected_desc);
78-
79- bool contains_dil_tensor_buffer = check_tensor_own_shade_context (tensor) && cpu::ShadeDataContext::isDilOwnTheTensor (tensor);
80- if (contains_dil_tensor_buffer) {
81- cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(tensor.storage ().data_ptr ().get_context ());
82- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (shade_context->dil_tensor .has_value ());
83- new_type_dil_tensor.feed_from (shade_context->dil_tensor .value ());
84- } else {
85- new_type_dil_tensor.feed_from (cpu::dbl::comm::dil_tensor_from_dense (tensor));
86- }
87-
88- equip_dil_buffer (tensor, new_type_dil_tensor);
61+ auto src = try_gen_dil_tensor (tensor);
62+ dil::tensor dst {expected_desc};
63+ dst.feed_from (src);
64+ equip_dil_buffer (tensor, dst);
8965}
9066
9167void equip_dil_buffer (const at::Tensor& tensor, dil::tensor dil_tensor_buffer) {
92- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (
68+ TORCH_CHECK (
9369 tensor.device ().is_dpcpp (),
9470 " dil buffer can only be equipped to dpcpp tensor" );
9571
96- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (
72+ TORCH_CHECK (
9773 check_tensor_own_whole_storage (tensor),
9874 " dil buffer can only be equipped to tensors that own the whole storage, "
9975 " as dil buffer is going to replace the original storage" );
0 commit comments