Skip to content

Commit 2a41250

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Allow data casting along with cloning. (#15510)
Summary: . Differential Revision: D86070966
1 parent a11d555 commit 2a41250

File tree

3 files changed

+164
-14
lines changed

3 files changed

+164
-14
lines changed

extension/tensor/tensor_ptr.cpp

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ TensorPtr make_tensor_ptr(
164164
[data = std::move(data)](void*) {});
165165
}
166166

167-
TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) {
167+
TensorPtr clone_tensor_ptr(
168+
const executorch::aten::Tensor& tensor,
169+
executorch::aten::ScalarType type) {
168170
std::vector<executorch::aten::SizesType> sizes(
169171
tensor.sizes().begin(), tensor.sizes().end());
170172
std::vector<executorch::aten::DimOrderType> dim_order{
@@ -178,23 +180,61 @@ TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) {
178180
#ifndef USE_ATEN_LIB
179181
dynamism = tensor.shape_dynamism();
180182
#endif // USE_ATEN_LIB
181-
return tensor.const_data_ptr()
182-
? make_tensor_ptr(
183+
const auto* tensor_data = tensor.const_data_ptr();
184+
if (!tensor_data) {
185+
return make_tensor_ptr(
183186
std::move(sizes),
184-
std::vector<uint8_t>(
185-
(uint8_t*)tensor.const_data_ptr(),
186-
(uint8_t*)tensor.const_data_ptr() + tensor.nbytes()),
187+
nullptr,
187188
std::move(dim_order),
188189
std::move(strides),
189-
tensor.scalar_type(),
190-
dynamism)
191-
: make_tensor_ptr(
190+
type,
191+
dynamism);
192+
}
193+
const auto tensor_type = tensor.scalar_type();
194+
if (tensor_type == type) {
195+
return make_tensor_ptr(
192196
std::move(sizes),
193-
nullptr,
197+
std::vector<uint8_t>(
198+
(uint8_t*)tensor_data,
199+
(uint8_t*)tensor_data + tensor.nbytes()),
194200
std::move(dim_order),
195201
std::move(strides),
196-
tensor.scalar_type(),
202+
tensor_type,
197203
dynamism);
204+
}
205+
ET_CHECK_MSG(
206+
runtime::canCast(tensor_type, type),
207+
"Cannot cast tensor type to desired type.");
208+
const auto tensor_numel = static_cast<size_t>(tensor.numel());
209+
std::vector<uint8_t> data(tensor_numel * aten::elementSize(type));
210+
211+
// Create a minimal context for error handling in ET_SWITCH
212+
struct {
213+
[[noreturn]] void fail(torch::executor::Error /* error */) {
214+
ET_CHECK_MSG(false, "Unsupported dtype in clone_tensor_ptr");
215+
}
216+
} ctx;
217+
218+
ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
219+
tensor_type, ctx, "clone_tensor_ptr_from", CTYPE_FROM, [&] {
220+
const CTYPE_FROM* tensor_data_ptr = static_cast<const CTYPE_FROM*>(tensor_data);
221+
ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
222+
type, ctx, "clone_tensor_ptr_to", CTYPE_TO, [&] {
223+
CTYPE_TO* data_ptr = reinterpret_cast<CTYPE_TO*>(data.data());
224+
std::transform(
225+
tensor_data_ptr,
226+
tensor_data_ptr + tensor_numel,
227+
data_ptr,
228+
[](const CTYPE_FROM& val) { return static_cast<CTYPE_TO>(val); });
229+
});
230+
});
231+
return make_tensor_ptr(
232+
std::move(sizes),
233+
std::move(data),
234+
std::move(dim_order),
235+
std::move(strides),
236+
type,
237+
dynamism);
198238
}
199239

200240
runtime::Error resize_tensor_ptr(

extension/tensor/tensor_ptr.h

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ inline TensorPtr make_tensor_ptr(
114114
ET_CHECK_MSG(
115115
runtime::canCast(deduced_type, type),
116116
"Cannot cast deduced type to specified type.");
117-
std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
117+
std::vector<uint8_t> casted_data(data.size() * aten::elementSize(type));
118118

119119
// Create a minimal context for error handling in ET_SWITCH
120120
struct {
@@ -408,6 +408,21 @@ inline TensorPtr make_tensor_ptr(
408408
[tensor_ptr](void*) {});
409409
}
410410

411+
/**
412+
* Creates a TensorPtr that manages a new Tensor with the same properties
413+
* as the given Tensor, but with a copy of the data owned by the returned
414+
* TensorPtr, or nullptr if the original data is null.
415+
*
416+
* @param tensor The Tensor to clone.
417+
* @param type The data type for the cloned tensor. The data will be cast
418+
* from the source tensor's type.
419+
* @return A new TensorPtr that manages a Tensor with the specified type
420+
* and copied/cast data.
421+
*/
422+
TensorPtr clone_tensor_ptr(
423+
const executorch::aten::Tensor& tensor,
424+
executorch::aten::ScalarType type);
425+
411426
/**
412427
* Creates a TensorPtr that manages a new Tensor with the same properties
413428
* as the given Tensor, but with a copy of the data owned by the returned
@@ -417,7 +432,25 @@ inline TensorPtr make_tensor_ptr(
417432
* @return A new TensorPtr that manages a Tensor with the same properties as the
418433
* original but with copied data.
419434
*/
420-
TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor);
435+
inline TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) {
436+
return clone_tensor_ptr(tensor, tensor.scalar_type());
437+
}
438+
439+
/**
440+
* Creates a new TensorPtr by cloning the given TensorPtr, copying the
441+
* underlying data.
442+
*
443+
* @param tensor The TensorPtr to clone.
444+
* @param type The data type for the cloned tensor. The data will be cast
445+
* from the source tensor's type.
446+
* @return A new TensorPtr that manages a Tensor with the specified type
447+
* and copied/cast data.
448+
*/
449+
inline TensorPtr clone_tensor_ptr(
450+
const TensorPtr& tensor,
451+
executorch::aten::ScalarType type) {
452+
return clone_tensor_ptr(*tensor, type);
453+
}
421454

422455
/**
423456
* Creates a new TensorPtr by cloning the given TensorPtr, copying the
@@ -428,7 +461,7 @@ TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor);
428461
* original but with copied data.
429462
*/
430463
inline TensorPtr clone_tensor_ptr(const TensorPtr& tensor) {
431-
return clone_tensor_ptr(*tensor);
464+
return clone_tensor_ptr(*tensor, tensor->scalar_type());
432465
}
433466

434467
/**

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,83 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorInt32) {
571571
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Int);
572572
}
573573

574+
TEST_F(TensorPtrTest, CloneTensorPtrCastInt32ToFloat) {
575+
std::vector<int32_t> data = {1, 2, 3, 4};
576+
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
577+
auto cloned_tensor =
578+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Float);
579+
580+
EXPECT_EQ(cloned_tensor->dim(), 2);
581+
EXPECT_EQ(cloned_tensor->size(0), 2);
582+
EXPECT_EQ(cloned_tensor->size(1), 2);
583+
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Float);
584+
auto ptr = cloned_tensor->const_data_ptr<float>();
585+
EXPECT_FLOAT_EQ(ptr[0], 1.0f);
586+
EXPECT_FLOAT_EQ(ptr[1], 2.0f);
587+
EXPECT_FLOAT_EQ(ptr[2], 3.0f);
588+
EXPECT_FLOAT_EQ(ptr[3], 4.0f);
589+
}
590+
591+
TEST_F(TensorPtrTest, CloneTensorPtrCastFloatToBFloat16) {
592+
std::vector<float> data = {1.0f, 2.0f, 3.5f};
593+
auto tensor = make_tensor_ptr({3}, std::move(data));
594+
auto cloned_tensor =
595+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::BFloat16);
596+
597+
EXPECT_EQ(cloned_tensor->dim(), 1);
598+
EXPECT_EQ(cloned_tensor->size(0), 3);
599+
EXPECT_EQ(
600+
cloned_tensor->scalar_type(), executorch::aten::ScalarType::BFloat16);
601+
auto ptr = cloned_tensor->const_data_ptr<executorch::aten::BFloat16>();
602+
EXPECT_NEAR(static_cast<float>(ptr[0]), 1.0f, 0.01f);
603+
EXPECT_NEAR(static_cast<float>(ptr[1]), 2.0f, 0.01f);
604+
EXPECT_NEAR(static_cast<float>(ptr[2]), 3.5f, 0.01f);
605+
}
606+
607+
TEST_F(TensorPtrTest, CloneTensorPtrCastKeepsMetadata) {
608+
std::vector<uint8_t> data(
609+
6 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
610+
auto tensor = make_tensor_ptr({2, 3}, std::move(data));
611+
auto cloned_tensor =
612+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Float);
613+
614+
EXPECT_EQ(cloned_tensor->dim(), 2);
615+
EXPECT_EQ(cloned_tensor->size(0), 2);
616+
EXPECT_EQ(cloned_tensor->size(1), 3);
617+
EXPECT_EQ(cloned_tensor->strides()[0], 3);
618+
EXPECT_EQ(cloned_tensor->strides()[1], 1);
619+
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Float);
620+
}
621+
622+
TEST_F(TensorPtrTest, CloneTensorPtrCastNullData) {
623+
auto tensor = make_tensor_ptr(
624+
{2, 2},
625+
nullptr,
626+
{},
627+
{},
628+
executorch::aten::ScalarType::Float,
629+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
630+
auto cloned_tensor =
631+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Int);
632+
633+
EXPECT_EQ(cloned_tensor->dim(), 2);
634+
EXPECT_EQ(cloned_tensor->size(0), 2);
635+
EXPECT_EQ(cloned_tensor->size(1), 2);
636+
EXPECT_EQ(cloned_tensor->const_data_ptr(), nullptr);
637+
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Int);
638+
}
639+
640+
TEST_F(TensorPtrTest, CloneTensorPtrCastInvalidExpectDeath) {
641+
std::vector<float> data = {1.0f, 2.0f};
642+
auto tensor = make_tensor_ptr({2}, std::move(data));
643+
ET_EXPECT_DEATH(
644+
{
645+
auto _ =
646+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Int);
647+
},
648+
"");
649+
}
650+
574651
TEST_F(TensorPtrTest, MakeTensorPtrFromTensorPtrInt32) {
575652
std::vector<int32_t> data = {1, 2, 3, 4};
576653
auto tensor = make_tensor_ptr({2, 2}, data);

0 commit comments

Comments
 (0)