Skip to content

Commit 2c8d070

Browse files
[ET][Portable][Build Size] Reduce build size of op_addmm
Pull Request resolved: #6018 200 K -> 30 K ghstack-source-id: 246944188 @exported-using-ghexport Differential Revision: [D63994874](https://our.internmc.facebook.com/intern/diff/D63994874/)
1 parent c34eb2b commit 2c8d070

File tree

2 files changed

+47
-54
lines changed

2 files changed

+47
-54
lines changed

kernels/portable/cpu/op_addmm.cpp

Lines changed: 45 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1111
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
1212
#include <executorch/kernels/portable/cpu/vec_ops.h>
1313
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -53,62 +53,53 @@ Tensor& addmm_out(
5353

5454
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
5555

56-
ScalarType alpha_dtype = utils::get_scalar_dtype(alpha);
57-
ScalarType beta_dtype = utils::get_scalar_dtype(beta);
58-
ET_SWITCH_REAL_TYPES_AND(
59-
Half, in.scalar_type(), ctx, "addmm.out", CTYPE, [&]() {
60-
ET_SWITCH_SCALAR_OBJ_TYPES(
61-
alpha_dtype, ctx, "addmm.out", ALPHA_T, [&]() {
62-
ET_SWITCH_SCALAR_OBJ_TYPES(
63-
beta_dtype, ctx, "addmm.out", BETA_T, [&]() {
64-
size_t m = mat1.size(0);
65-
size_t n = mat1.size(1);
66-
size_t p = mat2.size(1);
56+
static constexpr const char op_name[] = "addmm.out";
6757

68-
if (out.sizes() == in.sizes()) {
69-
// vec_addmm assumes that no broadcasting is required.
70-
vec_addmm<CTYPE, CTYPE>(
71-
out.mutable_data_ptr<CTYPE>(),
72-
in.const_data_ptr<CTYPE>(),
73-
mat1.const_data_ptr<CTYPE>(),
74-
mat2.const_data_ptr<CTYPE>(),
75-
m,
76-
n,
77-
p,
78-
convert<CTYPE>(beta.to<BETA_T>()),
79-
convert<CTYPE>(alpha.to<ALPHA_T>()));
80-
} else {
81-
// If broadcasting is required, then compute the matmul
82-
// and addition separately, using
83-
// apply_binary_elementwise_fn to perform the addition
84-
// while applying broadcasting
85-
vec_matmul<CTYPE, CTYPE>(
86-
out.mutable_data_ptr<CTYPE>(),
87-
mat1.const_data_ptr<CTYPE>(),
88-
mat2.const_data_ptr<CTYPE>(),
89-
m,
90-
n,
91-
p);
58+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
59+
CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
60+
CTYPE beta_val = utils::scalar_to<CTYPE>(beta);
61+
size_t m = mat1.size(0);
62+
size_t n = mat1.size(1);
63+
size_t p = mat2.size(1);
9264

93-
CTYPE alpha_val = convert<CTYPE>(alpha.to<ALPHA_T>());
94-
CTYPE beta_val = convert<CTYPE>(beta.to<BETA_T>());
95-
apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
96-
[alpha_val, beta_val](
97-
const CTYPE val_a, const CTYPE val_b) {
98-
CTYPE a_casted = static_cast<CTYPE>(val_a);
99-
CTYPE b_casted = static_cast<CTYPE>(val_b);
100-
CTYPE value =
101-
a_casted * alpha_val + b_casted * beta_val;
65+
if (out.sizes() == in.sizes()) {
66+
// vec_addmm assumes that no broadcasting is required.
67+
vec_addmm<CTYPE, CTYPE>(
68+
out.mutable_data_ptr<CTYPE>(),
69+
in.const_data_ptr<CTYPE>(),
70+
mat1.const_data_ptr<CTYPE>(),
71+
mat2.const_data_ptr<CTYPE>(),
72+
m,
73+
n,
74+
p,
75+
beta_val,
76+
alpha_val);
77+
} else {
78+
// If broadcasting is required, then compute the matmul
79+
// and addition separately, using
80+
// apply_binary_elementwise_fn to perform the addition
81+
// while applying broadcasting
82+
vec_matmul<CTYPE, CTYPE>(
83+
out.mutable_data_ptr<CTYPE>(),
84+
mat1.const_data_ptr<CTYPE>(),
85+
mat2.const_data_ptr<CTYPE>(),
86+
m,
87+
n,
88+
p);
10289

103-
return value;
104-
},
105-
out,
106-
in,
107-
out);
108-
}
109-
});
110-
});
111-
});
90+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
91+
[alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
92+
return val_a * alpha_val + val_b * beta_val;
93+
},
94+
ctx,
95+
out,
96+
utils::SupportedTensorDtypes::REALHBF16,
97+
in,
98+
utils::SupportedTensorDtypes::REALHBF16,
99+
out,
100+
utils::SupportedTensorDtypes::REALHBF16);
101+
}
102+
});
112103

113104
return out;
114105
}

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ ATEN_OPS = (
224224
name = "op_addmm",
225225
deps = [
226226
"//executorch/kernels/portable/cpu/util:broadcast_util",
227+
"//executorch/kernels/portable/cpu/util:dtype_util",
228+
"//executorch/kernels/portable/cpu/util:elementwise_util",
227229
"//executorch/kernels/portable/cpu/util:matmul_ops_util",
228230
":scalar_utils",
229231
":vec_ops",

0 commit comments

Comments
 (0)