
Commit 4153371

manuelcandales authored and facebook-github-bot committed
Introduce uni-tensor and bi-tensor equivalents (#6006)
Summary: Pull Request resolved: #6006

This introduces unitensor and bitensor equivalents to the tritensor template introduced by swolchok. They are applied to add.Scalar_out and add.out respectively, reducing op_add's build size. The unitensor template is also applied to clamp.out, reducing op_clamp's build size further.

Build size reduction:
- add: 484K -> 21K
- clamp: 119K -> 28K

ghstack-source-id: 246919722
exported-using-ghexport

Reviewed By: swolchok

Differential Revision: D63838076

fbshipit-source-id: 564d4435e51302856611b309ac0ef957baee3a43
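Why this shrinks the binary: the old code instantiated the full arithmetic lambda once per (a_type, b_type, out_type) combination via nested ET_SWITCH macros, while the new helpers compile the math once per common type and route dtype conversion through runtime-selected load/store functions. A minimal standalone sketch of that pattern follows; the names here are hypothetical stand-ins, not the ExecuTorch API:

// Sketch: one kernel instantiation per common type; per-dtype cost is
// reduced to small load/store thunks selected at runtime.
#include <cstdint>
#include <cstdio>

using LoadFn = float (*)(const void*);
using StoreFn = void (*)(float, void*);

template <typename T>
float load_as_float(const void* p) {
  return static_cast<float>(*static_cast<const T*>(p));
}
template <typename T>
void store_from_float(float v, void* p) {
  *static_cast<T*>(p) = static_cast<T>(v);
}

// Compiled once, regardless of how many concrete dtypes are supported.
void add_one_element(LoadFn load_a, LoadFn load_b, StoreFn store,
                     const void* a, const void* b, void* out) {
  store(load_a(a) + load_b(b), out);
}

int main() {
  int32_t a = 2;
  int16_t b = 3;
  float out = 0.0f;
  add_one_element(load_as_float<int32_t>, load_as_float<int16_t>,
                  store_from_float<float>, &a, &b, &out);
  std::printf("%f\n", out); // prints 5.000000
  return 0;
}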
1 parent a79caab commit 4153371

File tree: 4 files changed (+163, -146 lines)


kernels/portable/cpu/op_add.cpp

Lines changed: 34 additions & 106 deletions
@@ -7,64 +7,14 @@
  */
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template <typename CTYPE_IN>
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug<CTYPE_IN> {};
-
-} // namespace
 
 Tensor& add_out(
     KernelRuntimeContext& ctx,
@@ -80,7 +30,9 @@ Tensor& add_out(
 
   ET_KERNEL_CHECK(
       ctx,
-      executorch::runtime::tensor_is_realhbbf16_type(out),
+      (executorch::runtime::tensor_is_realhbbf16_type(a) &&
+       executorch::runtime::tensor_is_realhbbf16_type(b) &&
+       executorch::runtime::tensor_is_realhbbf16_type(out)),
      InvalidArgument,
      out);
   ET_KERNEL_CHECK(
@@ -96,25 +48,20 @@ Tensor& add_out(
   ET_KERNEL_CHECK(
       ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
 
-  constexpr auto name = "add.out";
-
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_IN = typename torch::executor::
-          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-      CTYPE_IN alpha_val;
-      utils::extract_scalar(alpha, &alpha_val);
-
-      ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
-        AddInner<
-            can_cast<CTYPE_IN, CTYPE_OUT>::value,
-            CTYPE_A,
-            CTYPE_B,
-            CTYPE_IN,
-            CTYPE_OUT>::run(a, b, alpha_val, out);
-      });
-    });
+  static constexpr const char op_name[] = "add.out";
+
+  ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMMON, op_name>(
+        [alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) {
+          CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
+          return val_a + val_alpha * val_b;
+        },
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
@@ -138,14 +85,14 @@ Tensor& add_scalar_out(
 
   ET_KERNEL_CHECK(
       ctx,
-      executorch::runtime::tensor_is_realhbbf16_type(out),
+      (executorch::runtime::tensor_is_realhbbf16_type(a) &&
+       executorch::runtime::tensor_is_realhbbf16_type(out)),
       InvalidArgument,
       out);
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
   ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
   ScalarType alpha_type = utils::get_scalar_dtype(alpha);
   ScalarType common_type =
       utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
@@ -155,42 +102,23 @@ Tensor& add_scalar_out(
   ET_KERNEL_CHECK(
       ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
 
-  if (common_type == ScalarType::Half) {
+  if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
     common_type = ScalarType::Float;
   }
 
-  constexpr auto name = "add.Scalar_out";
-
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_IN = typename utils::promote_type_with_scalar_type<
-          CTYPE_A,
-          CTYPE_B,
-          /*half_to_float*/ true>::type;
-      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-
-      CTYPE_B b_val;
-      utils::extract_scalar(b, &b_val);
-      CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-
-      CTYPE_IN alpha_val;
-      utils::extract_scalar(alpha, &alpha_val);
-
-      using CTYPE_OUT = typename std::conditional<
-          std::is_same<CTYPE_A, internal::F2>::value,
-          internal::F2,
-          CTYPE_IN>::type;
-
-      apply_unary_map_fn(
-          [b_casted, alpha_val](const CTYPE_A val_a) {
-            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-            CTYPE_IN value = a_casted + alpha_val * b_casted;
-            return static_cast<CTYPE_OUT>(value);
-          },
-          a.const_data_ptr<CTYPE_A>(),
-          out.mutable_data_ptr<CTYPE_OUT>(),
-          out.numel());
-    });
+  static constexpr const char op_name[] = "add.Scalar_out";
+
+  ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
+        [b, alpha](const CTYPE_COMMON val_a) {
+          CTYPE_COMMON val_b = utils::scalar_to<CTYPE_COMMON>(b);
+          CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
+          return val_a + val_alpha * val_b;
+        },
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
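The add.out rewrite above doubles as the usage pattern for the new helper. As a hedged sketch, here is how another binary op might adopt it; the op name "my_mul.out" and the surrounding kernel boilerplate (ctx, a, b, out, and a checked common_type) are hypothetical, while the call signature follows apply_bitensor_elementwise_fn as defined in the elementwise_util.h diff below:

// Hypothetical adopter of the new helper; not part of this commit.
static constexpr const char op_name[] = "my_mul.out";
ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
  utils::apply_bitensor_elementwise_fn<CTYPE_COMMON, op_name>(
      // The lambda sees both operands already converted to CTYPE_COMMON.
      [](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) {
        return val_a * val_b;
      },
      a,
      utils::SupportedTensorDtypes::REALHBBF16,
      b,
      utils::SupportedTensorDtypes::REALHBBF16,
      out,
      utils::SupportedTensorDtypes::REALHBBF16);
});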

kernels/portable/cpu/op_clamp.cpp

Lines changed: 19 additions & 37 deletions
@@ -13,7 +13,6 @@
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
 #include <executorch/kernels/portable/cpu/util/math_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -122,43 +121,26 @@ Tensor& clamp_out(
 
   ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
 
-  ET_SWITCH_REALH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
-    // Extract optional min value
-    CTYPE_OUT min = 0;
-    if (has_min) {
-      ET_SWITCH_SCALAR_OBJ_TYPES(min_type, ctx, "clamp", CTYPE_MIN, [&]() {
-        CTYPE_MIN min_val = 0;
-        utils::extract_scalar(min_opt.value(), &min_val);
-        min = static_cast<CTYPE_OUT>(min_val);
-      });
-    }
-
-    // Extract optional max value
-    CTYPE_OUT max = 0;
-    if (has_max) {
-      ET_SWITCH_SCALAR_OBJ_TYPES(max_type, ctx, "clamp", CTYPE_MAX, [&]() {
-        CTYPE_MAX max_val = 0;
-        utils::extract_scalar(max_opt.value(), &max_val);
-        max = static_cast<CTYPE_OUT>(max_val);
-      });
-    }
+  static constexpr const char op_name[] = "clamp.out";
 
-    ET_SWITCH_REALHB_TYPES(in_type, ctx, "clamp", CTYPE_IN, [&]() {
-      apply_unary_map_fn(
-          [has_min, min, has_max, max](const CTYPE_IN val_in) {
-            CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
-            if (has_min) {
-              val_out = utils::max_override(val_out, min);
-            }
-            if (has_max) {
-              val_out = utils::min_override(val_out, max);
-            }
-            return val_out;
-          },
-          in.const_data_ptr<CTYPE_IN>(),
-          out.mutable_data_ptr<CTYPE_OUT>(),
-          in.numel());
-    });
+  ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
+        [has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) {
+          CTYPE_COMMON val_out = val_in;
+          if (has_min) {
+            val_out = utils::max_override(
+                val_out, utils::scalar_to<CTYPE_COMMON>(min_opt.value()));
+          }
+          if (has_max) {
+            val_out = utils::min_override(
+                val_out, utils::scalar_to<CTYPE_COMMON>(max_opt.value()));
+          }
+          return val_out;
+        },
+        in,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
  });
 
   return out;
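The clamp rewrite leans on the new utils::scalar_to<T> helper (added in the elementwise_util.h diff below) to read the optional min/max Scalars directly as the common type, replacing the per-dtype extract_scalar switches. A standalone sketch of those conversion semantics follows; MiniScalar is a hypothetical stand-in, not ExecuTorch's Scalar:

// Sketch: a Scalar's payload (bool, int64_t, or double) is converted to the
// requested common type, mirroring scalar_to<T> below.
#include <cstdint>
#include <cstdio>

struct MiniScalar {
  enum class Tag { Bool, Int, Double } tag;
  union {
    bool b;
    int64_t i;
    double d;
  };
};

template <typename T>
T scalar_to(const MiniScalar& s) {
  switch (s.tag) {
    case MiniScalar::Tag::Bool:
      return static_cast<T>(s.b);
    case MiniScalar::Tag::Double:
      return static_cast<T>(s.d);
    default:
      return static_cast<T>(s.i);
  }
}

int main() {
  MiniScalar alpha;
  alpha.tag = MiniScalar::Tag::Double;
  alpha.d = 2.5;
  std::printf("%f\n", scalar_to<float>(alpha));                 // 2.500000
  std::printf("%lld\n", (long long)scalar_to<int64_t>(alpha));  // 2 (truncated)
  return 0;
}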

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 109 additions & 1 deletion
@@ -16,6 +16,33 @@ namespace executor {
 namespace native {
 namespace utils {
 
+/*
+ * Convert Scalar to C++ type
+ */
+
+template <typename T>
+T scalar_to(const Scalar& s) {
+  if (s.isBoolean()) {
+    return static_cast<T>(s.to<bool>());
+  } else if (s.isFloatingPoint()) {
+    return static_cast<T>(s.to<double>());
+  } else {
+    return static_cast<T>(s.to<int64_t>());
+  }
+}
+
+template <>
+inline double scalar_to<double>(const Scalar& s) {
+  return s.isFloatingPoint() ? s.to<double>()
+                             : static_cast<double>(s.to<int64_t>());
+}
+
+template <>
+inline int64_t scalar_to<int64_t>(const Scalar& s) {
+  return s.isFloatingPoint() ? static_cast<int64_t>(s.to<double>())
+                             : s.to<int64_t>();
+}
+
 namespace internal {
 
 template <typename To, typename From>
@@ -139,6 +166,86 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
 
 } // namespace internal
 
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_unitensor_elementwise_fn(
+    const Op& compute_fun,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  const auto load_a_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  const auto store_common_to_out =
+      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
+          out, out_dtypes);
+  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
+  const auto a_element_size = a.element_size();
+  const auto out_element_size = out.element_size();
+  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+
+  auto out_numel = out.numel();
+  for (size_t i = 0; i < out_numel; ++i) {
+    auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
+    store_common_to_out(result, &data_out[i * out_element_size]);
+  }
+}
+
+/**
+ * Useful for bi-tensor elementwise operators. For each element of the inputs,
+ * perform a computation and write to the corresponding element of the output.
+ * Tensor broadcasting is applied wherever it is required.
+ */
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_bitensor_elementwise_fn(
+    const Op& compute_fun,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& b,
+    SupportedTensorDtypes b_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
+  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
+  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
+
+  const auto load_a_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  const auto load_b_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
+  const auto store_common_to_out =
+      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
+          out, out_dtypes);
+  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
+  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
+  const auto a_element_size = a.element_size();
+  const auto b_element_size = b.element_size();
+  const auto out_element_size = out.element_size();
+  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+
+  auto out_numel = out.numel();
+  for (size_t i = 0; i < out_numel; ++i) {
+    size_t a_linear_index = i;
+    size_t b_linear_index = i;
+
+    if (any_is_broadcasted) {
+      size_t out_indexes[kTensorDimensionLimit];
+      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
+
+      if (a_is_broadcasted) {
+        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
+      }
+      if (b_is_broadcasted) {
+        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
+      }
+    }
+
+    auto result = compute_fun(
+        load_a_to_common(&data_a[a_linear_index * a_element_size]),
+        load_b_to_common(&data_b[b_linear_index * b_element_size]));
+    store_common_to_out(result, &data_out[i * out_element_size]);
+  }
+}
+
 /**
  * Useful for tri-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
@@ -194,7 +301,8 @@ inline void apply_tritensor_elementwise_fn(
   const auto out_element_size = out.element_size();
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
-  for (size_t i = 0; i < out.numel(); ++i) {
+  auto out_numel = out.numel();
+  for (size_t i = 0; i < out_numel; ++i) {
     size_t a_linear_index = i;
     size_t b_linear_index = i;
     size_t c_linear_index = i;
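apply_bitensor_elementwise_fn reuses the tritensor broadcasting scheme: each output index is delinearized into multi-dimensional coordinates, which are then re-linearized against any input whose sizes differ from the output's. A standalone sketch of that index mapping follows; it is a simplification for a fixed 2-D case, not the actual delinearize_index / linearize_access_indexes helpers:

// Sketch: a {2, 3} output reading a {1, 3} input broadcast along dim 0.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t out_sizes[2] = {2, 3};
  const size_t b_sizes[2] = {1, 3};
  for (size_t i = 0; i < 6; ++i) {
    // Delinearize the flat output index i into (row, col) coordinates.
    size_t coord[2] = {i / out_sizes[1], i % out_sizes[1]};
    // Re-linearize against b, clamping broadcast dimensions to index 0.
    size_t b_index = (b_sizes[0] == 1 ? 0 : coord[0]) * b_sizes[1] +
                     (b_sizes[1] == 1 ? 0 : coord[1]);
    std::printf("out[%zu] reads b[%zu]\n", i, b_index);  // b index cycles 0,1,2
  }
  return 0;
}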
