|
@@ -7,7 +7,7 @@
 */
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
 #include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -53,62 +53,53 @@ Tensor& addmm_out(
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
-  ScalarType alpha_dtype = utils::get_scalar_dtype(alpha);
-  ScalarType beta_dtype = utils::get_scalar_dtype(beta);
-  ET_SWITCH_REAL_TYPES_AND(
-      Half, in.scalar_type(), ctx, "addmm.out", CTYPE, [&]() {
-        ET_SWITCH_SCALAR_OBJ_TYPES(
-            alpha_dtype, ctx, "addmm.out", ALPHA_T, [&]() {
-              ET_SWITCH_SCALAR_OBJ_TYPES(
-                  beta_dtype, ctx, "addmm.out", BETA_T, [&]() {
-                    size_t m = mat1.size(0);
-                    size_t n = mat1.size(1);
-                    size_t p = mat2.size(1);
+  static constexpr const char op_name[] = "addmm.out";
 
-                    if (out.sizes() == in.sizes()) {
-                      // vec_addmm assumes that no broadcasting is required.
-                      vec_addmm<CTYPE, CTYPE>(
-                          out.mutable_data_ptr<CTYPE>(),
-                          in.const_data_ptr<CTYPE>(),
-                          mat1.const_data_ptr<CTYPE>(),
-                          mat2.const_data_ptr<CTYPE>(),
-                          m,
-                          n,
-                          p,
-                          convert<CTYPE>(beta.to<BETA_T>()),
-                          convert<CTYPE>(alpha.to<ALPHA_T>()));
-                    } else {
-                      // If broadcasting is required, them compute the matmul
-                      // and addition separately, using
-                      // apply_binary_elementwise_fn to perform the addition
-                      // while applying broadcasting
-                      vec_matmul<CTYPE, CTYPE>(
-                          out.mutable_data_ptr<CTYPE>(),
-                          mat1.const_data_ptr<CTYPE>(),
-                          mat2.const_data_ptr<CTYPE>(),
-                          m,
-                          n,
-                          p);
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
+    CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
+    CTYPE beta_val = utils::scalar_to<CTYPE>(beta);
+    size_t m = mat1.size(0);
+    size_t n = mat1.size(1);
+    size_t p = mat2.size(1);
 
-                      CTYPE alpha_val = convert<CTYPE>(alpha.to<ALPHA_T>());
-                      CTYPE beta_val = convert<CTYPE>(beta.to<BETA_T>());
-                      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
-                          [alpha_val, beta_val](
-                              const CTYPE val_a, const CTYPE val_b) {
-                            CTYPE a_casted = static_cast<CTYPE>(val_a);
-                            CTYPE b_casted = static_cast<CTYPE>(val_b);
-                            CTYPE value =
-                                a_casted * alpha_val + b_casted * beta_val;
+    if (out.sizes() == in.sizes()) {
+      // vec_addmm assumes that no broadcasting is required.
+      vec_addmm<CTYPE, CTYPE>(
+          out.mutable_data_ptr<CTYPE>(),
+          in.const_data_ptr<CTYPE>(),
+          mat1.const_data_ptr<CTYPE>(),
+          mat2.const_data_ptr<CTYPE>(),
+          m,
+          n,
+          p,
+          beta_val,
+          alpha_val);
+    } else {
+      // If broadcasting is required, then compute the matmul
+      // and addition separately, using
+      // apply_bitensor_elementwise_fn to perform the addition
+      // while applying broadcasting.
+      vec_matmul<CTYPE, CTYPE>(
+          out.mutable_data_ptr<CTYPE>(),
+          mat1.const_data_ptr<CTYPE>(),
+          mat2.const_data_ptr<CTYPE>(),
+          m,
+          n,
+          p);
 
-                            return value;
-                          },
-                          out,
-                          in,
-                          out);
-                    }
-                  });
-            });
-      });
+      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
+          [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
+            return val_a * alpha_val + val_b * beta_val;
+          },
+          ctx,
+          out,
+          utils::SupportedTensorDtypes::REALHBF16,
+          in,
+          utils::SupportedTensorDtypes::REALHBF16,
+          out,
+          utils::SupportedTensorDtypes::REALHBF16);
+    }
+  });
 
   return out;
 }
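For context, `addmm` computes `out = beta * in + alpha * (mat1 @ mat2)`. When `out` and `in` have identical sizes, the patched kernel keeps the fused `vec_addmm` fast path; otherwise it materializes the matmul into `out` and then applies the `val_a * alpha_val + val_b * beta_val` lambda with broadcasting via `utils::apply_bitensor_elementwise_fn`. Below is a minimal standalone sketch of those two steps. It is not ExecuTorch code: the helper names `naive_matmul` and `naive_addmm_broadcast` are hypothetical, and the sketch only handles the common single-bias-row broadcast, whereas `apply_bitensor_elementwise_fn` supports general broadcast shapes.

```cpp
#include <cstddef>
#include <vector>

// Plain (m x n) * (n x p) matrix multiply into a preallocated buffer,
// standing in for the vec_matmul call in the kernel.
template <typename T>
void naive_matmul(
    T* out, const T* a, const T* b, size_t m, size_t n, size_t p) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < p; ++j) {
      T acc = 0;
      for (size_t k = 0; k < n; ++k) {
        acc += a[i * n + k] * b[k * p + j];
      }
      out[i * p + j] = acc;
    }
  }
}

// Combine the matmul result (already in `out`) with `in`, mirroring the
// lambda in the patch: val_a * alpha + val_b * beta, where val_a comes
// from the matmul and val_b from `in`. `in` is broadcast across rows
// when it has a single row (the typical bias case).
template <typename T>
void naive_addmm_broadcast(
    T* out,
    const T* in,
    size_t in_rows,
    size_t m,
    size_t p,
    T beta,
    T alpha) {
  for (size_t i = 0; i < m; ++i) {
    const T* in_row = in + (in_rows == 1 ? 0 : i * p);
    for (size_t j = 0; j < p; ++j) {
      out[i * p + j] = out[i * p + j] * alpha + in_row[j] * beta;
    }
  }
}

int main() {
  std::vector<float> mat1 = {1, 2, 3, 4}; // 2 x 2
  std::vector<float> mat2 = {5, 6, 7, 8}; // 2 x 2
  std::vector<float> bias = {10, 20};     // 1 x 2, broadcast over both rows
  std::vector<float> out(4);
  naive_matmul(out.data(), mat1.data(), mat2.data(), 2, 2, 2);
  naive_addmm_broadcast(out.data(), bias.data(), 1, 2, 2, 1.0f, 1.0f);
  // out is now {29, 42, 53, 70} == (mat1 @ mat2) + bias.
  return 0;
}
```

Note the argument order in `vec_addmm` above: `beta_val` is passed before `alpha_val`, and the sketch's helper keeps the same beta-before-alpha order to avoid confusion.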
|