Skip to content

Commit a79caab

Browse files
manuelcandales authored and facebook-github-bot
committed
Move to elementwise_utils as apply_tritensor_elementwise_fn (#6005)
Summary: Pull Request resolved: #6005 swolchok's technique is superior to the TensorReader/TensorWriter approach I introduced in D63703174. So, I am rewriting my build size reduction stack on top of his approach. Superior how? - It should lead to smaller overall build size. Current measurements indicate this. Complete data will be published after the stack is complete. - It is better suited for dtype selective build, since it passes the op name to all of the ET_SWITCHes involved. - It is more performant. Current measurements to clamp.Tensor_out indicate this. Note that in the data below, my stack is marginally more performant for the vanilla case (no broadcast & all dtypes equal), but this is only because I added a "fast path" in my code for such vanilla case, which can be trivially added to Scott's approach as well. It is more relevant to compare numbers for mixed dtype or broadcasting. ``` Baseline clamp.Tensor_out no broadcast float: 25451 [23423 - 28839] microseconds clamp.Tensor_out no broadcast double: 25461 [23377 - 50940] microseconds clamp.Tensor_out no broadcast mixed dtype: 23367 [21353 - 27022] microseconds clamp.Tensor_out broadcast: 702529 [679667 - 742005] microseconds Manuel C clamp.Tensor_out no broadcast float: 22919 [21333 - 27140] microseconds clamp.Tensor_out no broadcast double: 23095 [21472 - 27462] microseconds clamp.Tensor_out no broadcast mixed dtype: 35042 [32875 - 42491] microseconds clamp.Tensor_out broadcast: 936541 [916437 - 971499] microseconds Scott W clamp.Tensor_out no broadcast float: 28263 [26458 - 32832] microseconds clamp.Tensor_out no broadcast double: 27442 [25548 - 39417] microseconds clamp.Tensor_out no broadcast mixed dtype: 25592 [23620 - 30148] microseconds clamp.Tensor_out broadcast: 695399 [674244 - 738919] microseconds ``` Build size reduction after Scott's diffs touching clamp.Tensor_out and where.self_out: - clamp: 7.42 MB -> 119 KB - where: 106 KB -> 16 KB ghstack-source-id: 246919714 exported-using-ghexport 
Reviewed By: malfet, swolchok Differential Revision: D63838072 fbshipit-source-id: 1710bf791bf6866bd4e2cbc0c1409004d50dac8b
1 parent d271825 commit a79caab

File tree

6 files changed

+271
-177
lines changed

6 files changed

+271
-177
lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <limits>
1313

1414
#include <executorch/kernels/portable/cpu/scalar_utils.h>
15-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
15+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1616
#include <executorch/kernels/portable/cpu/util/functional_util.h>
1717
#include <executorch/kernels/portable/cpu/util/math_util.h>
1818
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -215,7 +215,7 @@ Tensor& clamp_tensor_out(
215215
static constexpr const char op_name[] = "clamp.Tensor_out";
216216

217217
ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
218-
apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
218+
utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
219219
[has_min, has_max](
220220
const CTYPE_COMMON val_in,
221221
const CTYPE_COMMON val_min,
@@ -230,13 +230,13 @@ Tensor& clamp_tensor_out(
230230
return val_out;
231231
},
232232
in,
233-
SupportedTensorDtypes::REALHBBF16,
233+
utils::SupportedTensorDtypes::REALHBBF16,
234234
min,
235-
SupportedTensorDtypes::REALHBBF16,
235+
utils::SupportedTensorDtypes::REALHBBF16,
236236
max,
237-
SupportedTensorDtypes::REALHBBF16,
237+
utils::SupportedTensorDtypes::REALHBBF16,
238238
out,
239-
SupportedTensorDtypes::REALHBBF16);
239+
utils::SupportedTensorDtypes::REALHBBF16);
240240
});
241241

242242
return out;

kernels/portable/cpu/op_where.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10-
#include <executorch/kernels/portable/cpu/util/functional_util.h>
9+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1110
#include <executorch/runtime/kernel/kernel_includes.h>
1211

1312
namespace torch {
@@ -44,19 +43,20 @@ Tensor& where_out(
4443
cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
4544
"Unhandled dtype %s for where.self_out",
4645
torch::executor::toString(cond_type));
46+
4747
ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
48-
apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
48+
utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
4949
[](const CTYPE_COMMON val_a,
5050
const CTYPE_COMMON val_b,
5151
const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; },
5252
a,
53-
SupportedTensorDtypes::REALHBBF16,
53+
utils::SupportedTensorDtypes::REALHBBF16,
5454
b,
55-
SupportedTensorDtypes::REALHBBF16,
55+
utils::SupportedTensorDtypes::REALHBBF16,
5656
cond,
57-
SupportedTensorDtypes::BOOL_OR_BYTE,
57+
utils::SupportedTensorDtypes::BOOL_OR_BYTE,
5858
out,
59-
SupportedTensorDtypes::SAME_AS_COMMON);
59+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
6060
});
6161

6262
return out;

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 13 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -270,125 +270,6 @@ size_t linearize_access_indexes(
270270
// Mapping with broadcasting
271271
//
272272

273-
namespace internal {
274-
template <typename To, typename From>
275-
To load_and_convert(const void* fromPtr) {
276-
return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
277-
}
278-
279-
template <typename To, typename From>
280-
void convert_and_store(From f, void* dst) {
281-
*reinterpret_cast<To*>(dst) = static_cast<To>(f);
282-
}
283-
284-
template <typename CTYPE_COMMON>
285-
using load_to_common_fn = CTYPE_COMMON (*)(const void*);
286-
287-
template <typename CTYPE_COMMON, const char* op_name>
288-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
289-
const Tensor& t) {
290-
CTYPE_COMMON (*result)(const void*) = nullptr;
291-
ET_SWITCH_REALHBBF16_TYPES(
292-
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
293-
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
294-
});
295-
return result;
296-
}
297-
298-
template <typename CTYPE_COMMON, const char* op_name>
299-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
300-
const Tensor& t) {
301-
CTYPE_COMMON (*result)(const void*) = nullptr;
302-
ET_SWITCH_TWO_TYPES(
303-
Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
304-
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
305-
});
306-
return result;
307-
}
308-
309-
template <typename CTYPE_COMMON>
310-
using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
311-
312-
template <typename CTYPE_COMMON, const char* op_name>
313-
store_common_to_tensor_fn<CTYPE_COMMON>
314-
get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
315-
void (*result)(CTYPE_COMMON, void*) = nullptr;
316-
ET_SWITCH_REALHBBF16_TYPES(
317-
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
318-
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
319-
});
320-
return result;
321-
}
322-
323-
template <typename CTYPE_COMMON, const char* op_name>
324-
store_common_to_tensor_fn<CTYPE_COMMON>
325-
get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
326-
void (*result)(CTYPE_COMMON, void*) = nullptr;
327-
ET_SWITCH_TWO_TYPES(
328-
Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
329-
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
330-
});
331-
return result;
332-
}
333-
} // namespace internal
334-
335-
enum class SupportedTensorDtypes {
336-
REALHBBF16,
337-
BOOL_OR_BYTE,
338-
SAME_AS_COMMON,
339-
};
340-
341-
namespace internal {
342-
template <typename CTYPE_COMMON, const char* op_name>
343-
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
344-
const Tensor& t,
345-
SupportedTensorDtypes dtypes) {
346-
switch (dtypes) {
347-
case SupportedTensorDtypes::REALHBBF16:
348-
return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
349-
case SupportedTensorDtypes::BOOL_OR_BYTE:
350-
return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
351-
case SupportedTensorDtypes::SAME_AS_COMMON: {
352-
constexpr auto common_scalar_type =
353-
CppTypeToScalarType<CTYPE_COMMON>::value;
354-
ET_CHECK_MSG(
355-
t.scalar_type() == common_scalar_type,
356-
"Unhandled dtype %s for %s",
357-
::executorch::runtime::toString(common_scalar_type),
358-
op_name);
359-
return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
360-
}
361-
}
362-
ET_CHECK(false);
363-
return nullptr;
364-
}
365-
366-
template <typename CTYPE_COMMON, const char* op_name>
367-
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
368-
const Tensor& t,
369-
SupportedTensorDtypes dtypes) {
370-
switch (dtypes) {
371-
case SupportedTensorDtypes::REALHBBF16:
372-
return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
373-
case SupportedTensorDtypes::BOOL_OR_BYTE:
374-
return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
375-
t);
376-
case SupportedTensorDtypes::SAME_AS_COMMON: {
377-
constexpr auto common_scalar_type =
378-
CppTypeToScalarType<CTYPE_COMMON>::value;
379-
ET_CHECK_MSG(
380-
t.scalar_type() == common_scalar_type,
381-
"Unhandled dtype %s for %s",
382-
::executorch::runtime::toString(common_scalar_type),
383-
op_name);
384-
return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
385-
}
386-
}
387-
ET_CHECK(false);
388-
return nullptr;
389-
}
390-
} // namespace internal
391-
392273
/**
393274
* Useful for binary elementwise operators. For each element of the inputs,
394275
* perform a computation and write to the corresponding element of the output.
@@ -432,56 +313,29 @@ inline void apply_binary_elementwise_fn(
432313
* Useful for ternary elementwise operators. For each element of the inputs,
433314
* perform a computation and write to the corresponding element of the output.
434315
* Tensor broadcasting is applied wherever it is required.
435-
*
436-
* In order to mitigate build time cost (straightforwardly |CTYPE_A| *
437-
* |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
438-
* are passed as CTYPE_COMMON.
439-
*
440-
* Each tensor's supported dtypes set must be provided. The tensor
441-
* will be checked to ensure that its dtype falls into that set.
442-
*
443-
* op_name is used to support dtype selective build, as with the
444-
* ET_SWITCH family of macros. Note: because of C++17 quirks, you
445-
* can't pass a string literal for op_name. Instead, you should do the
446-
* following:
447-
*
448-
* static constexpr const char op_name[] = "my_op";
449-
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
450316
*/
451-
template <typename CTYPE_COMMON, const char* op_name, typename Op>
317+
template <
318+
typename CTYPE_A,
319+
typename CTYPE_B,
320+
typename CTYPE_C,
321+
typename CTYPE_OUT,
322+
typename Op>
452323
inline void apply_ternary_elementwise_fn(
453324
const Op& compute_fun,
454325
const Tensor& a,
455-
SupportedTensorDtypes a_dtypes,
456326
const Tensor& b,
457-
SupportedTensorDtypes b_dtypes,
458327
const Tensor& c,
459-
SupportedTensorDtypes c_dtypes,
460-
const Tensor& out,
461-
SupportedTensorDtypes out_dtypes) {
328+
const Tensor& out) {
462329
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
463330
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
464331
const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
465332
const bool any_is_broadcasted =
466333
(a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
467334

468-
const auto load_a_to_common =
469-
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
470-
const auto load_b_to_common =
471-
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
472-
const auto load_c_to_common =
473-
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
474-
const auto store_common_to_out =
475-
internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
476-
out, out_dtypes);
477-
const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
478-
const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
479-
const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
480-
const auto a_element_size = a.element_size();
481-
const auto b_element_size = b.element_size();
482-
const auto c_element_size = c.element_size();
483-
const auto out_element_size = out.element_size();
484-
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
335+
const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
336+
const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
337+
const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
338+
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
485339

486340
for (size_t i = 0; i < out.numel(); ++i) {
487341
size_t a_linear_index = i;
@@ -503,11 +357,8 @@ inline void apply_ternary_elementwise_fn(
503357
}
504358
}
505359

506-
auto result = compute_fun(
507-
load_a_to_common(&data_a[a_linear_index * a_element_size]),
508-
load_b_to_common(&data_b[b_linear_index * b_element_size]),
509-
load_c_to_common(&data_c[c_linear_index * c_element_size]));
510-
store_common_to_out(result, &data_out[i * out_element_size]);
360+
data_out[i] = compute_fun(
361+
data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
511362
}
512363
}
513364

0 commit comments

Comments (0)