77 */
88
99#include < executorch/kernels/portable/cpu/scalar_utils.h>
10- #include < executorch/kernels/portable/cpu/util/broadcast_util.h>
11- #include < executorch/kernels/portable/cpu/util/functional_util.h>
10+ #include < executorch/kernels/portable/cpu/util/elementwise_util.h>
1211#include < executorch/kernels/portable/cpu/util/kernel_ops_util.h>
1312#include < executorch/runtime/kernel/kernel_includes.h>
1413#include < executorch/runtime/platform/assert.h>
1514
1615namespace torch {
1716namespace executor {
1817namespace native {
19- namespace {
20-
21- template <
22- bool can_cast,
23- typename CTYPE_A,
24- typename CTYPE_B,
25- typename CTYPE_IN,
26- typename CTYPE_OUT>
27- struct AddInner ;
28-
29- template <
30- typename CTYPE_A,
31- typename CTYPE_B,
32- typename CTYPE_IN,
33- typename CTYPE_OUT>
34- struct AddInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
35- static void
36- run (const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
37- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
38- // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
39- [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
40- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
41- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
42- CTYPE_IN value = a_casted + alpha_val * b_casted;
43-
44- return static_cast <CTYPE_OUT>(value);
45- },
46- a,
47- b,
48- out);
49- }
50- };
51-
52- template <typename CTYPE_IN>
53- struct ReportCanCastBug {
54- static void run (const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
55- ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
56- }
57- };
58-
59- template <
60- typename CTYPE_A,
61- typename CTYPE_B,
62- typename CTYPE_IN,
63- typename CTYPE_OUT>
64- struct AddInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
65- : public ReportCanCastBug<CTYPE_IN> {};
66-
67- } // namespace
6818
6919Tensor& add_out (
7020 KernelRuntimeContext& ctx,
@@ -80,7 +30,9 @@ Tensor& add_out(
8030
8131 ET_KERNEL_CHECK (
8232 ctx,
83- executorch::runtime::tensor_is_realhbbf16_type (out),
33+ (executorch::runtime::tensor_is_realhbbf16_type (a) &&
34+ executorch::runtime::tensor_is_realhbbf16_type (b) &&
35+ executorch::runtime::tensor_is_realhbbf16_type (out)),
8436 InvalidArgument,
8537 out);
8638 ET_KERNEL_CHECK (
@@ -96,25 +48,20 @@ Tensor& add_out(
9648 ET_KERNEL_CHECK (
9749 ctx, check_alpha_type (alpha_type, common_type), InvalidArgument, out);
9850
99- constexpr auto name = " add.out" ;
100-
101- ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
102- ET_SWITCH_REALHBBF16_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
103- using CTYPE_IN = typename torch::executor::
104- promote_types<CTYPE_A, CTYPE_B, /* half_to_float*/ true >::type;
105- ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
106- CTYPE_IN alpha_val;
107- utils::extract_scalar (alpha, &alpha_val);
108-
109- ET_SWITCH_REALHBBF16_TYPES (out_type, ctx, name, CTYPE_OUT, [&]() {
110- AddInner<
111- can_cast<CTYPE_IN, CTYPE_OUT>::value,
112- CTYPE_A,
113- CTYPE_B,
114- CTYPE_IN,
115- CTYPE_OUT>::run (a, b, alpha_val, out);
116- });
117- });
51+ static constexpr const char op_name[] = " add.out" ;
52+
53+ ET_SWITCH_REALB_TYPES (common_type, ctx, op_name, CTYPE_COMMON, [&]() {
54+ utils::apply_bitensor_elementwise_fn<CTYPE_COMMON, op_name>(
55+ [alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) {
56+ CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
57+ return val_a + val_alpha * val_b;
58+ },
59+ a,
60+ utils::SupportedTensorDtypes::REALHBBF16,
61+ b,
62+ utils::SupportedTensorDtypes::REALHBBF16,
63+ out,
64+ utils::SupportedTensorDtypes::REALHBBF16);
11865 });
11966
12067 return out;
@@ -138,14 +85,14 @@ Tensor& add_scalar_out(
13885
13986 ET_KERNEL_CHECK (
14087 ctx,
141- executorch::runtime::tensor_is_realhbbf16_type (out),
88+ (executorch::runtime::tensor_is_realhbbf16_type (a) &&
89+ executorch::runtime::tensor_is_realhbbf16_type (out)),
14290 InvalidArgument,
14391 out);
14492 ET_KERNEL_CHECK (
14593 ctx, tensors_have_same_dim_order (a, out), InvalidArgument, out);
14694
14795 ScalarType a_type = a.scalar_type ();
148- ScalarType b_type = utils::get_scalar_dtype (b);
14996 ScalarType alpha_type = utils::get_scalar_dtype (alpha);
15097 ScalarType common_type =
15198 utils::promote_type_with_scalar (a_type, b, /* half_to_float*/ false );
@@ -155,42 +102,23 @@ Tensor& add_scalar_out(
155102 ET_KERNEL_CHECK (
156103 ctx, check_alpha_type (alpha_type, common_type), InvalidArgument, out);
157104
158- if (common_type == ScalarType::Half) {
105+ if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16 ) {
159106 common_type = ScalarType::Float;
160107 }
161108
162- constexpr auto name = " add.Scalar_out" ;
163-
164- ET_SWITCH_REALHBBF16_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
165- ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
166- using CTYPE_IN = typename utils::promote_type_with_scalar_type<
167- CTYPE_A,
168- CTYPE_B,
169- /* half_to_float*/ true >::type;
170- ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
171-
172- CTYPE_B b_val;
173- utils::extract_scalar (b, &b_val);
174- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
175-
176- CTYPE_IN alpha_val;
177- utils::extract_scalar (alpha, &alpha_val);
178-
179- using CTYPE_OUT = typename std::conditional<
180- std::is_same<CTYPE_A, internal::F2>::value,
181- internal::F2,
182- CTYPE_IN>::type;
183-
184- apply_unary_map_fn (
185- [b_casted, alpha_val](const CTYPE_A val_a) {
186- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
187- CTYPE_IN value = a_casted + alpha_val * b_casted;
188- return static_cast <CTYPE_OUT>(value);
189- },
190- a.const_data_ptr <CTYPE_A>(),
191- out.mutable_data_ptr <CTYPE_OUT>(),
192- out.numel ());
193- });
109+ static constexpr const char op_name[] = " add.Scalar_out" ;
110+
111+ ET_SWITCH_REALB_TYPES (common_type, ctx, op_name, CTYPE_COMMON, [&]() {
112+ utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
113+ [b, alpha](const CTYPE_COMMON val_a) {
114+ CTYPE_COMMON val_b = utils::scalar_to<CTYPE_COMMON>(b);
115+ CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
116+ return val_a + val_alpha * val_b;
117+ },
118+ a,
119+ utils::SupportedTensorDtypes::REALHBBF16,
120+ out,
121+ utils::SupportedTensorDtypes::REALHBBF16);
194122 });
195123
196124 return out;
0 commit comments