@@ -51,6 +51,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
5151 return result;
5252}
5353
54+ template <typename CTYPE_COMMON, const char * op_name>
55+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realh (const Tensor& t) {
56+ CTYPE_COMMON (*result)(const void *) = nullptr ;
57+ ET_SWITCH_REALH_TYPES (t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
58+ result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
59+ });
60+ return result;
61+ }
62+
5463template <typename CTYPE_COMMON, const char * op_name>
5564load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16 (
5665 const Tensor& t) {
@@ -72,6 +81,16 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
7281 return result;
7382}
7483
84+ template <typename CTYPE_COMMON, const char * op_name>
85+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool (const Tensor& t) {
86+ ET_CHECK_MSG (
87+ t.scalar_type () == ScalarType::Bool,
88+ " Unhandled dtype %s for %s" ,
89+ ::executorch::runtime::toString (t.scalar_type()),
90+ op_name);
91+ return internal::load_and_convert<CTYPE_COMMON, bool >;
92+ }
93+
7594template <typename CTYPE_COMMON, const char * op_name>
7695load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte (
7796 const Tensor& t) {
@@ -137,6 +156,16 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
137156 return result;
138157}
139158
159+ template <typename CTYPE_COMMON, const char * op_name>
160+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realh (
161+ const Tensor& t) {
162+ void (*result)(CTYPE_COMMON, void *) = nullptr ;
163+ ET_SWITCH_REALH_TYPES (t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
164+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
165+ });
166+ return result;
167+ }
168+
140169template <typename CTYPE_COMMON, const char * op_name>
141170store_common_to_tensor_fn<CTYPE_COMMON>
142171get_store_common_to_tensor_fn_floathbf16 (const Tensor& t) {
@@ -159,6 +188,17 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
159188 return result;
160189}
161190
191+ template <typename CTYPE_COMMON, const char * op_name>
192+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_bool (
193+ const Tensor& t) {
194+ ET_CHECK_MSG (
195+ t.scalar_type () == ScalarType::Bool,
196+ " Unhandled dtype %s for %s" ,
197+ ::executorch::runtime::toString (t.scalar_type()),
198+ op_name);
199+ return internal::convert_and_store<bool , CTYPE_COMMON>;
200+ }
201+
162202template <typename CTYPE_COMMON, const char * op_name>
163203store_common_to_tensor_fn<CTYPE_COMMON>
164204get_store_common_to_tensor_fn_bool_or_byte (const Tensor& t) {
@@ -206,8 +246,10 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
206246enum class SupportedTensorDtypes {
207247 REALHBBF16,
208248 REALHBF16,
249+ REALH,
209250 FLOATHBF16,
210251 INTB,
252+ BOOL,
211253 BOOL_OR_BYTE,
212254 SAME_AS_COMPUTE,
213255 SAME_AS_COMMON,
@@ -224,10 +266,14 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
224266 return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
225267 case SupportedTensorDtypes::REALHBF16:
226268 return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
269+ case SupportedTensorDtypes::REALH:
270+ return get_load_to_common_fn_realh<CTYPE_COMMON, op_name>(t);
227271 case SupportedTensorDtypes::FLOATHBF16:
228272 return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
229273 case SupportedTensorDtypes::INTB:
230274 return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
275+ case SupportedTensorDtypes::BOOL:
276+ return get_load_to_common_fn_bool<CTYPE_COMMON, op_name>(t);
231277 case SupportedTensorDtypes::BOOL_OR_BYTE:
232278 return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
233279 case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -248,10 +294,14 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
248294 return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
249295 case SupportedTensorDtypes::REALHBF16:
250296 return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
297+ case SupportedTensorDtypes::REALH:
298+ return get_store_common_to_tensor_fn_realh<CTYPE_COMMON, op_name>(t);
251299 case SupportedTensorDtypes::FLOATHBF16:
252300 return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
253301 case SupportedTensorDtypes::INTB:
254302 return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
303+ case SupportedTensorDtypes::BOOL:
304+ return get_store_common_to_tensor_fn_bool<CTYPE_COMMON, op_name>(t);
255305 case SupportedTensorDtypes::BOOL_OR_BYTE:
256306 return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
257307 t);
0 commit comments