Skip to content

Commit 7606476

Browse files
manuelcandales authored and facebook-github-bot committed
REALHBF16 binary ops: maximum, minimum, mul (#6009)
Summary: Pull Request resolved: #6009 - mul: 1.69 M -> 15 K - maximum: 353 K -> 11 K - minimum: 353 K -> 11 K ghstack-source-id: 246985131 exported-using-ghexport Reviewed By: swolchok Differential Revision: D63909726 fbshipit-source-id: 83b87bd789026194f65067ea952c57b83266583f
1 parent f8d182b commit 7606476

File tree

5 files changed

+119
-265
lines changed

5 files changed

+119
-265
lines changed

kernels/portable/cpu/op_maximum.cpp

Lines changed: 28 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,98 +7,55 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1111
#include <executorch/kernels/portable/cpu/util/math_util.h>
1212
#include <executorch/runtime/kernel/kernel_includes.h>
1313

1414
namespace torch {
1515
namespace executor {
1616
namespace native {
17-
namespace {
18-
19-
template <
20-
bool can_cast,
21-
typename CTYPE_A,
22-
typename CTYPE_B,
23-
typename CTYPE_IN,
24-
typename CTYPE_OUT>
25-
struct MaximumInner;
26-
27-
template <
28-
typename CTYPE_A,
29-
typename CTYPE_B,
30-
typename CTYPE_IN,
31-
typename CTYPE_OUT>
32-
struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
33-
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
34-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
35-
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
36-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
37-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
38-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
39-
CTYPE_IN value = utils::max_override(a_casted, b_casted);
40-
41-
return static_cast<CTYPE_OUT>(value);
42-
},
43-
a,
44-
b,
45-
out);
46-
}
47-
};
48-
49-
struct ReportCanCastBug {
50-
static void run(const Tensor&, const Tensor&, Tensor&) {
51-
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
52-
}
53-
};
54-
55-
template <
56-
typename CTYPE_A,
57-
typename CTYPE_B,
58-
typename CTYPE_IN,
59-
typename CTYPE_OUT>
60-
struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
61-
: public ReportCanCastBug {};
62-
63-
} // namespace
6417

6518
/// Computes the element-wise maximum of tensors `a` and `b` into `out`,
/// with broadcasting. Post-refactor version using elementwise_util: all
/// input/output dtypes in REALHBBF16 are handled through a single compute
/// type, which is what collapses the former CTYPE_A x CTYPE_B x CTYPE_OUT
/// template instantiation blow-up (353 K -> 11 K per the commit summary).
///
/// @param ctx Kernel runtime context used for error propagation.
/// @param a   First input tensor.
/// @param b   Second input tensor (broadcast-compatible with `a`).
/// @param out Output tensor; resized to the broadcast target size.
/// @returns   `out`, for chaining. On a failed check, `out` is returned
///            after the error is recorded on `ctx` via ET_KERNEL_CHECK.
Tensor& maximum_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype: e.g. Half/BFloat16 inputs are computed in float.
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "maximum.out";

  // Single switch on the compute type only; apply_bitensor_elementwise_fn
  // performs the per-element load/convert/store for each supported dtype.
  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
          // max_override handles NaN propagation consistently with ATen.
          return utils::max_override(val_a, val_b);
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return out;
}

kernels/portable/cpu/op_minimum.cpp

Lines changed: 28 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,98 +7,55 @@
77
*/
88

99
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1111
#include <executorch/kernels/portable/cpu/util/math_util.h>
1212
#include <executorch/runtime/kernel/kernel_includes.h>
1313

1414
namespace torch {
1515
namespace executor {
1616
namespace native {
17-
namespace {
18-
19-
template <
20-
bool can_cast,
21-
typename CTYPE_A,
22-
typename CTYPE_B,
23-
typename CTYPE_IN,
24-
typename CTYPE_OUT>
25-
struct MinimumInner;
26-
27-
template <
28-
typename CTYPE_A,
29-
typename CTYPE_B,
30-
typename CTYPE_IN,
31-
typename CTYPE_OUT>
32-
struct MinimumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
33-
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
34-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
35-
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
36-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
37-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
38-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
39-
CTYPE_IN value = utils::min_override(a_casted, b_casted);
40-
41-
return static_cast<CTYPE_OUT>(value);
42-
},
43-
a,
44-
b,
45-
out);
46-
}
47-
};
48-
49-
struct ReportCanCastBug {
50-
static void run(const Tensor&, const Tensor&, Tensor&) {
51-
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
52-
}
53-
};
54-
55-
template <
56-
typename CTYPE_A,
57-
typename CTYPE_B,
58-
typename CTYPE_IN,
59-
typename CTYPE_OUT>
60-
struct MinimumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
61-
: public ReportCanCastBug {};
62-
63-
} // namespace
6417

6518
/// Computes the element-wise minimum of tensors `a` and `b` into `out`,
/// with broadcasting. Post-refactor version using elementwise_util: all
/// input/output dtypes in REALHBBF16 are handled through a single compute
/// type, which is what collapses the former CTYPE_A x CTYPE_B x CTYPE_OUT
/// template instantiation blow-up (353 K -> 11 K per the commit summary).
///
/// @param ctx Kernel runtime context used for error propagation.
/// @param a   First input tensor.
/// @param b   Second input tensor (broadcast-compatible with `a`).
/// @param out Output tensor; resized to the broadcast target size.
/// @returns   `out`, for chaining. On a failed check, `out` is returned
///            after the error is recorded on `ctx` via ET_KERNEL_CHECK.
Tensor& minimum_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

  // Check Common Dtype
  ET_KERNEL_CHECK(
      ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype: e.g. Half/BFloat16 inputs are computed in float.
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "minimum.out";

  // Single switch on the compute type only; apply_bitensor_elementwise_fn
  // performs the per-element load/convert/store for each supported dtype.
  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
          // min_override handles NaN propagation consistently with ATen.
          return utils::min_override(val_a, val_b);
        },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return out;
}

0 commit comments

Comments (0)