Skip to content

Commit 0d910f0

Browse files
committed
Update
[ghstack-poisoned]
1 parent 2411c99 commit 0d910f0

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

runtime/core/exec_aten/testing_util/tensor_factory.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,10 @@ class TensorFactory {
279279
t = empty_strided(sizes, strides);
280280
}
281281
if (t.nbytes() > 0) {
282-
memcpy(t.template data<true_ctype>(), data.data(), t.nbytes());
282+
std::transform(
283+
data.begin(), data.end(), t.template data<true_ctype>(), [](auto x) {
284+
return static_cast<true_ctype>(x);
285+
});
283286
}
284287
return t;
285288
}
@@ -319,7 +322,10 @@ class TensorFactory {
319322
t = empty_strided(sizes, strides);
320323
}
321324
if (t.nbytes() > 0) {
322-
memcpy(t.template data<true_ctype>(), data.data(), t.nbytes());
325+
std::transform(
326+
data.begin(), data.end(), t.template data<true_ctype>(), [](auto x) {
327+
return static_cast<true_ctype>(x);
328+
});
323329
}
324330
return t;
325331
}
@@ -721,6 +727,13 @@ class TensorFactory {
721727
*/
722728
using ctype = typename internal::ScalarTypeToCppTypeWrapper<DTYPE>::ctype;
723729

730+
/**
731+
* The official C type for the scalar type. Used when accessing elements
732+
* of a constructed Tensor.
733+
*/
734+
using true_ctype =
735+
typename executorch::runtime::ScalarTypeToCppType<DTYPE>::type;
736+
724737
TensorFactory() = default;
725738

726739
/**
@@ -1019,7 +1032,14 @@ class TensorFactory {
10191032
data_.data(),
10201033
dim_order_.data(),
10211034
strides_.data(),
1022-
dynamism) {}
1035+
dynamism) {
1036+
// The only valid values for bool are 0 and 1; coerce!
1037+
if constexpr (std::is_same_v<true_ctype, bool>) {
1038+
for (auto& x : data_) {
1039+
x = static_cast<true_ctype>(x);
1040+
}
1041+
}
1042+
}
10231043

10241044
std::vector<int32_t> sizes_;
10251045
std::vector<ctype> data_;

0 commit comments

Comments
 (0)