@@ -279,7 +279,10 @@ class TensorFactory {
279279 t = empty_strided (sizes, strides);
280280 }
281281 if (t.nbytes () > 0 ) {
282- memcpy (t.template data <true_ctype>(), data.data (), t.nbytes ());
282+ std::transform (
283+ data.begin (), data.end (), t.template data <true_ctype>(), [](auto x) {
284+ return static_cast <true_ctype>(x);
285+ });
283286 }
284287 return t;
285288 }
@@ -319,7 +322,10 @@ class TensorFactory {
319322 t = empty_strided (sizes, strides);
320323 }
321324 if (t.nbytes () > 0 ) {
322- memcpy (t.template data <true_ctype>(), data.data (), t.nbytes ());
325+ std::transform (
326+ data.begin (), data.end (), t.template data <true_ctype>(), [](auto x) {
327+ return static_cast <true_ctype>(x);
328+ });
323329 }
324330 return t;
325331 }
@@ -721,6 +727,13 @@ class TensorFactory {
721727 */
722728 using ctype = typename internal::ScalarTypeToCppTypeWrapper<DTYPE>::ctype;
723729
730+ /* *
731+ * The official C type for the scalar type. Used when accessing elements
732+ * of a constructed Tensor.
733+ */
734+ using true_ctype =
735+ typename executorch::runtime::ScalarTypeToCppType<DTYPE>::type;
736+
724737 TensorFactory () = default ;
725738
726739 /* *
@@ -1019,7 +1032,14 @@ class TensorFactory {
10191032 data_.data(),
10201033 dim_order_.data(),
10211034 strides_.data(),
1022- dynamism) {}
1035+ dynamism) {
1036+ // The only valid values for bool are 0 and 1; coerce!
1037+ if constexpr (std::is_same_v<true_ctype, bool >) {
1038+ for (auto & x : data_) {
1039+ x = static_cast <true_ctype>(x);
1040+ }
1041+ }
1042+ }
10231043
10241044 std::vector<int32_t > sizes_;
10251045 std::vector<ctype> data_;
0 commit comments