Skip to content

Commit 85fbe0d

Browse files
[ET][Portable][Build Size] REALHBF16 binary ops: maximum, minimum, mul
Binary size reductions: mul: 1.69 MB -> 15 KB; maximum: 353 KB -> 11 KB; minimum: 353 KB -> 11 KB. Differential Revision: [D63909726](https://our.internmc.facebook.com/intern/diff/D63909726/) [ghstack-poisoned]
1 parent bc8dcaa commit 85fbe0d

File tree

4 files changed

+108
-265
lines changed

4 files changed

+108
-265
lines changed

kernels/portable/cpu/op_maximum.cpp

Lines changed: 27 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,98 +7,54 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1111
#include <executorch/kernels/portable/cpu/util/math_util.h>
1212
#include <executorch/runtime/kernel/kernel_includes.h>
1313

1414
namespace torch {
1515
namespace executor {
1616
namespace native {
17-
namespace {
18-
19-
template <
20-
bool can_cast,
21-
typename CTYPE_A,
22-
typename CTYPE_B,
23-
typename CTYPE_IN,
24-
typename CTYPE_OUT>
25-
struct MaximumInner;
26-
27-
template <
28-
typename CTYPE_A,
29-
typename CTYPE_B,
30-
typename CTYPE_IN,
31-
typename CTYPE_OUT>
32-
struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
33-
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
34-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
35-
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
36-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
37-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
38-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
39-
CTYPE_IN value = utils::max_override(a_casted, b_casted);
40-
41-
return static_cast<CTYPE_OUT>(value);
42-
},
43-
a,
44-
b,
45-
out);
46-
}
47-
};
48-
49-
struct ReportCanCastBug {
50-
static void run(const Tensor&, const Tensor&, Tensor&) {
51-
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
52-
}
53-
};
54-
55-
template <
56-
typename CTYPE_A,
57-
typename CTYPE_B,
58-
typename CTYPE_IN,
59-
typename CTYPE_OUT>
60-
struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
61-
: public ReportCanCastBug {};
62-
63-
} // namespace
6417

6518
/// Computes the element-wise maximum of `a` and `b` (with broadcasting)
/// into `out`, returning `out`. On a failed check, reports the error via
/// `ctx` and returns `out` unmodified.
Tensor& maximum_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // The result dtype follows standard promotion of the two input dtypes
  // and must be castable to the dtype of `out`.
  const ScalarType common_type =
      promoteTypes(a.scalar_type(), b.scalar_type());
  ET_KERNEL_CHECK(
      ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

  // All three tensors must agree on dim order.
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Grow `out` to the broadcast target shape of the inputs.
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Arithmetic runs in a single compute dtype; the elementwise helper
  // performs the per-dtype loads/stores, which keeps the number of
  // instantiated templates (and thus binary size) small.
  const ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "maximum.out";

  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
        [](const CTYPE lhs, const CTYPE rhs) {
          // max_override handles NaN propagation consistently with ATen.
          return utils::max_override(lhs, rhs);
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return out;
}

kernels/portable/cpu/op_minimum.cpp

Lines changed: 27 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,98 +7,54 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1111
#include <executorch/kernels/portable/cpu/util/math_util.h>
1212
#include <executorch/runtime/kernel/kernel_includes.h>
1313

1414
namespace torch {
1515
namespace executor {
1616
namespace native {
17-
namespace {
18-
19-
template <
20-
bool can_cast,
21-
typename CTYPE_A,
22-
typename CTYPE_B,
23-
typename CTYPE_IN,
24-
typename CTYPE_OUT>
25-
struct MinimumInner;
26-
27-
template <
28-
typename CTYPE_A,
29-
typename CTYPE_B,
30-
typename CTYPE_IN,
31-
typename CTYPE_OUT>
32-
struct MinimumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
33-
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
34-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
35-
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
36-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
37-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
38-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
39-
CTYPE_IN value = utils::min_override(a_casted, b_casted);
40-
41-
return static_cast<CTYPE_OUT>(value);
42-
},
43-
a,
44-
b,
45-
out);
46-
}
47-
};
48-
49-
struct ReportCanCastBug {
50-
static void run(const Tensor&, const Tensor&, Tensor&) {
51-
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
52-
}
53-
};
54-
55-
template <
56-
typename CTYPE_A,
57-
typename CTYPE_B,
58-
typename CTYPE_IN,
59-
typename CTYPE_OUT>
60-
struct MinimumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
61-
: public ReportCanCastBug {};
62-
63-
} // namespace
6417

6518
/// Computes the element-wise minimum of `a` and `b` (with broadcasting)
/// into `out`, returning `out`. On a failed check, reports the error via
/// `ctx` and returns `out` unmodified.
Tensor& minimum_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // The result dtype follows standard promotion of the two input dtypes
  // and must be castable to the dtype of `out`.
  const ScalarType common_type =
      promoteTypes(a.scalar_type(), b.scalar_type());
  ET_KERNEL_CHECK(
      ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

  // All three tensors must agree on dim order.
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Grow `out` to the broadcast target shape of the inputs.
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Arithmetic runs in a single compute dtype; the elementwise helper
  // performs the per-dtype loads/stores, which keeps the number of
  // instantiated templates (and thus binary size) small.
  const ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "minimum.out";

  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
        [](const CTYPE lhs, const CTYPE rhs) {
          // min_override handles NaN propagation consistently with ATen.
          return utils::min_override(lhs, rhs);
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return out;
}

0 commit comments

Comments
 (0)