@@ -280,7 +280,6 @@ template <typename To, typename From>
280280void convert_and_store (From f, void * dst) {
281281 *reinterpret_cast <To*>(dst) = static_cast <To>(f);
282282}
283- } // namespace internal
284283
285284template <typename CTYPE_COMMON>
286285using load_to_common_fn = CTYPE_COMMON (*)(const void *);
@@ -296,6 +295,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
296295 return result;
297296}
298297
298+ template <typename CTYPE_COMMON, const char * op_name>
299+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte (const Tensor& t) {
300+ CTYPE_COMMON (*result)(const void *) = nullptr ;
301+ ET_SWITCH_TWO_TYPES (Bool, Byte, t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
302+ result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
303+ });
304+ return result;
305+ }
306+
299307template <typename CTYPE_COMMON>
300308using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void *);
301309
@@ -310,6 +318,72 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
310318 return result;
311319}
312320
321+ template <typename CTYPE_COMMON, const char * op_name>
322+ store_common_to_tensor_fn<CTYPE_COMMON>
323+ get_store_common_to_tensor_fn_bool_or_byte (const Tensor& t) {
324+ void (*result)(CTYPE_COMMON, void *) = nullptr ;
325+ ET_SWITCH_TWO_TYPES (Bool, Byte,
326+ t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
327+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
328+ });
329+ return result;
330+ }
331+ } // namespace internal
332+
333+ enum class SupportedTensorDtypes {
334+ REALHBBF16,
335+ BOOL_OR_BYTE,
336+ SAME_AS_COMMON,
337+ };
338+
339+ namespace internal {
340+ template <typename CTYPE_COMMON, const char * op_name>
341+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn (
342+ const Tensor& t,
343+ SupportedTensorDtypes dtypes) {
344+ switch (dtypes) {
345+ case SupportedTensorDtypes::REALHBBF16:
346+ return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
347+ case SupportedTensorDtypes::BOOL_OR_BYTE:
348+ return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
349+ case SupportedTensorDtypes::SAME_AS_COMMON: {
350+ constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
351+ ET_CHECK_MSG (
352+ t.scalar_type () == common_scalar_type,
353+ " Unhandled dtype %s for %s" ,
354+ ::executorch::runtime::toString (common_scalar_type),
355+ op_name);
356+ return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
357+ }
358+ }
359+ ET_CHECK (false );
360+ return nullptr ;
361+ }
362+
363+ template <typename CTYPE_COMMON, const char * op_name>
364+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn (
365+ const Tensor& t,
366+ SupportedTensorDtypes dtypes) {
367+ switch (dtypes) {
368+ case SupportedTensorDtypes::REALHBBF16:
369+ return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
370+ case SupportedTensorDtypes::BOOL_OR_BYTE:
371+ return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
372+ case SupportedTensorDtypes::SAME_AS_COMMON: {
373+ constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
374+ ET_CHECK_MSG (
375+ t.scalar_type () == common_scalar_type,
376+ " Unhandled dtype %s for %s" ,
377+ ::executorch::runtime::toString (common_scalar_type),
378+ op_name);
379+ return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
380+ }
381+ }
382+ ET_CHECK (false );
383+ return nullptr ;
384+ }
385+ } // namespace internal
386+
313387/* *
314388 * Useful for binary elementwise operators. For each element of the inputs,
315389 * perform a computation and write to the corresponding element of the output.
@@ -356,33 +430,45 @@ inline void apply_binary_elementwise_fn(
356430 *
357431 * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
358432 * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
359- * are passed as CTYPE_COMMON. We require compute_fun to return
360- * CTYPE_COMMON, and we require loading conversion functions from each
361- * input type to CTYPE_COMMON and a storing conversion from
362- * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
363- * must take a void* pointing to an element of the corresponding
364- * tensor, load that element, and convert it to CTYPE_COMMON. The
365- * storing conversion function must have the signature
366- * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
367- * and store it to the given location.
433+ * are passed as CTYPE_COMMON.
434+ *
435+ * Each tensor's supported dtypes set must be provided. The tensor
436+ * will be checked to ensure that its dtype falls into that set.
437+ *
438+ * op_name is used to support dtype selective build, as with the
439+ * ET_SWITCH family of macros. Note: because of C++17 quirks, you
440+ * can't pass a string literal for op_name. Instead, you should do the
441+ * following:
442+ *
443+ * static constexpr const char op_name[] = "my_op";
444+ * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
368445 */
369- template <typename CTYPE_COMMON, typename Op>
446+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
370447inline void apply_ternary_elementwise_fn (
371448 const Op& compute_fun,
372449 const Tensor& a,
450+ SupportedTensorDtypes a_dtypes,
373451 const Tensor& b,
452+ SupportedTensorDtypes b_dtypes,
374453 const Tensor& c,
454+ SupportedTensorDtypes c_dtypes,
375455 const Tensor& out,
376- CTYPE_COMMON (*load_a_to_common)(const void *),
377- CTYPE_COMMON (*load_b_to_common)(const void *),
378- CTYPE_COMMON (*load_c_to_common)(const void *),
379- void (*store_common_to_out)(CTYPE_COMMON, void *)) {
456+ SupportedTensorDtypes out_dtypes) {
380457 const bool a_is_broadcasted = !out.sizes ().equals (a.sizes ());
381458 const bool b_is_broadcasted = !out.sizes ().equals (b.sizes ());
382459 const bool c_is_broadcasted = !out.sizes ().equals (c.sizes ());
383460 const bool any_is_broadcasted =
384461 (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
385462
463+ const auto load_a_to_common =
464+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
465+ const auto load_b_to_common =
466+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
467+ const auto load_c_to_common =
468+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
469+ const auto store_common_to_out =
470+ internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
471+ out, out_dtypes);
386472 const char * const data_a = reinterpret_cast <const char *>(a.const_data_ptr ());
387473 const char * const data_b = reinterpret_cast <const char *>(b.const_data_ptr ());
388474 const char * const data_c = reinterpret_cast <const char *>(c.const_data_ptr ());
0 commit comments