Commit 5ce9c59

Wrong shape if the tensor is mixed precision (#100)
If the shape of the dil tensor is not the same as that of its tensor wrapper, the wrapper ends up with an incorrect shape, because the reorder syncs the dil tensor's shape back to the wrapper. In fact, we should treat the dil tensor only as a buffer and never sync its meta info back to the tensor wrapper. It is JUST a buffer.
1 parent 4f7b02b commit 5ce9c59
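
To make the failure mode concrete, here is a minimal sketch of the scenario the new test exercises, written in the style of tests/cpu/test_lazy_reorder.py. The package import name and the device alias below are assumptions borrowed from how that test file is typically set up, not part of this commit; the shapes and the view behaviour come from the test added here.

import torch
import intel_pytorch_extension as ipex                  # assumed import name
from intel_pytorch_extension import AutoMixPrecision    # assumed location

device = ipex.DEVICE  # assumed device alias used by the test suite

with AutoMixPrecision(True):
    src_1 = torch.randn(5120, 1, 128, device=device)
    src_2 = torch.randn(5120, 1, 128, device=device)

    # Under mixed precision the result of the add is backed by a BF16 dil
    # tensor whose buffer shape is [5120, 1, 128].
    res_bf16 = src_1 + src_2

    # view() forces a reorder of the dil buffer back to a public FP32 layout.
    # The dil buffer still reports [5120, 1, 128]
    # (ipex.core.get_dil_tensor_sizes(res_fp32_view)), but the aten wrapper of
    # the view must keep its own sizes rather than being overwritten by the
    # unconditional shape sync this commit removes.
    res_fp32_view = res_bf16.view(1280, 4, 1, 128)
    assert list(res_fp32_view.size()) == [1280, 4, 1, 128]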

File tree

2 files changed: 24 additions & 1 deletion

tests/cpu/test_lazy_reorder.py
torch_ipex/csrc/aten_ipex_bridge.cpp

tests/cpu/test_lazy_reorder.py

Lines changed: 23 additions & 0 deletions
@@ -1082,6 +1082,29 @@ def test_view(self):
         out_dpcpp = x_dpcpp * y
         self.assertEqual(out_cpu, out_dpcpp)
 
+        with AutoMixPrecision(True):
+            # test share storage for view
+            src_1 = torch.randn(5120, 1, 128, device=device)
+            src_2 = torch.randn(5120, 1, 128, device=device)
+            res_bf16 = src_1 + src_2
+            res_bf16_other = src_1 + src_2
+            self.assertTrue(ipex.core.is_dil_tensor(res_bf16))
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(res_bf16))
+            self.assertTrue(ipex.core.get_dil_tensor_sizes(res_bf16), [5120, 1, 128])
+            self.assertEqual(list(res_bf16.size()), [5120, 1, 128])
+            res_fp32_view = res_bf16.view(1280, 4, 1, 128)
+            self.assertTrue(ipex.core.is_dil_tensor(res_bf16))
+            self.assertTrue(ipex.core.is_dil_tensor(res_fp32_view))
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(res_bf16))
+            self.assertFalse(ipex.core.is_bf16_dil_tensor(res_fp32_view))
+            self.assertEqual(list(res_fp32_view.size()), [1280, 4, 1, 128])
+            tmp_res = res_bf16 + res_bf16_other
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(res_bf16))
+            self.assertTrue(ipex.core.is_bf16_dil_tensor(res_fp32_view))
+            tmp_res = res_fp32_view.index_select(0, torch.LongTensor([0, 1]))
+            self.assertTrue(ipex.core.get_dil_tensor_sizes(res_fp32_view), [5120, 1, 128])
+            self.assertTrue(ipex.core.get_dil_tensor_sizes(res_fp32_view), [5120, 1, 128])
+            self.assertEqual(list(tmp_res.size()), [2, 4, 1, 128])
 
 class TestSoftMax(TestCase):
     def test_softmax(self):

torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
     auto aten_tensor_scalar_type = ipexTensor.scalar_type();
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(aten_tensor_scalar_type == at::kFloat || aten_tensor_scalar_type == at::kBFloat16);
     pub_tensor = dil_tensor.to_public(nullptr, get_dil_data_type(aten_tensor_scalar_type));
+    cpu::dbl::comm::sync_shape_from_dil_to_aten(ipexTensor, pub_tensor);
   }
 
   if (!pub_tensor.is_empty()) {
@@ -114,7 +115,6 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
       ipexTensor.device().type());
 
     ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
-    cpu::dbl::comm::sync_shape_from_dil_to_aten(ipexTensor, pub_tensor);
   }
 }
