Skip to content

Commit c428ed4

Browse files
authored
fix shape mismatch (#91)
1 parent 1b16e38 commit c428ed4

File tree

1 file changed

+11
-35
lines changed

1 file changed

+11
-35
lines changed

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -48,52 +48,28 @@ void reorder_to_bf16_for_mix_prec(const at::Tensor& tensor) {
4848
}
4949

5050
void reorder_to_dtype(const at::Tensor& tensor, at::ScalarType dst_scalar_type) {
51-
dil::tensor::memory::desc dst_desc;
52-
if (check_tensor_own_shade_context(tensor) && cpu::ShadeDataContext::isDilOwnTheTensor(tensor)) {
53-
// The buffer ownership is DIL
54-
cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(tensor.storage().data_ptr().get_context());
55-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_context->dil_tensor.has_value());
56-
dil::tensor&& dil_tensor = std::move(shade_context->dil_tensor.value());
57-
if (get_at_data_type(dil_tensor.get_data_type()) == dst_scalar_type) {
58-
// The data type of DIL tensor is same as the destination data type. DO NOTHING
59-
return;
60-
}
61-
62-
TORCH_CHECK(check_tensor_own_whole_storage(tensor),
63-
"Intel Extension for PyTorch does not support the data is just a part of its storage for auto mix-precision.");
64-
dst_desc = dil_tensor.get_desc().to_type(get_dil_data_type(dst_scalar_type));
65-
} else {
66-
// The buffer ownership is CPU
67-
TORCH_CHECK(check_tensor_own_whole_storage(tensor),
68-
"Intel Extension for PyTorch does not support the data is just a part of its storage for auto mix-precision.");
69-
dil::tensor temp_dil_tensor = cpu::dbl::comm::dil_tensor_from_dense(tensor);
70-
dst_desc = temp_dil_tensor.get_desc().to_type(get_dil_data_type(dst_scalar_type));
51+
auto src = try_gen_dil_tensor(tensor);
52+
if (get_at_data_type(src.get_data_type()) == dst_scalar_type) {
53+
// The data type of DIL tensor is same as the dst data type. DO NOTHING
54+
return;
7155
}
56+
auto dst_desc = src.get_desc().to_type(get_dil_data_type(dst_scalar_type));
7257
reorder_to_desc(tensor, dst_desc);
7358
}
7459

7560
void reorder_to_desc(const at::Tensor& tensor, const dil::tensor::desc& expected_desc) {
76-
dil::tensor new_type_dil_tensor;
77-
new_type_dil_tensor.init(expected_desc);
78-
79-
bool contains_dil_tensor_buffer = check_tensor_own_shade_context(tensor) && cpu::ShadeDataContext::isDilOwnTheTensor(tensor);
80-
if (contains_dil_tensor_buffer) {
81-
cpu::ShadeDataContext *shade_context = (cpu::ShadeDataContext*)(tensor.storage().data_ptr().get_context());
82-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(shade_context->dil_tensor.has_value());
83-
new_type_dil_tensor.feed_from(shade_context->dil_tensor.value());
84-
} else {
85-
new_type_dil_tensor.feed_from(cpu::dbl::comm::dil_tensor_from_dense(tensor));
86-
}
87-
88-
equip_dil_buffer(tensor, new_type_dil_tensor);
61+
auto src = try_gen_dil_tensor(tensor);
62+
dil::tensor dst {expected_desc};
63+
dst.feed_from(src);
64+
equip_dil_buffer(tensor, dst);
8965
}
9066

9167
void equip_dil_buffer(const at::Tensor& tensor, dil::tensor dil_tensor_buffer) {
92-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
68+
TORCH_CHECK(
9369
tensor.device().is_dpcpp(),
9470
"dil buffer can only be equipped to dpcpp tensor");
9571

96-
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
72+
TORCH_CHECK(
9773
check_tensor_own_whole_storage(tensor),
9874
"dil buffer can only be equipped to tensors that own the whole storage, "
9975
"as dil buffer is going to replace the original storage");

0 commit comments

Comments
 (0)