Commit 405f531

manuelcandales authored and facebook-github-bot committed
Introduce REALHBF16. Binary ops: pow, rsub, sub (#6010)
Summary:
Pull Request resolved: #6010

- pow: 1.03 M -> 17 K
- rsub: 372 K -> 11 K
- sub: 263 K -> 21 K

ghstack-source-id: 246985133
exported-using-ghexport

Reviewed By: swolchok

Differential Revision: D63909724

fbshipit-source-id: ac173388c29e0cf25b5ac0893beaeecba5f52435
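The size reductions come from collapsing the dtype dispatch. The old kernels instantiated the op lambda once per (input dtype x input dtype x output dtype) combination through nested ET_SWITCH macros; the new elementwise_util helpers switch on a single compute type and convert element values at the load/store boundary. Below is a minimal standalone C++ sketch of that idea; every name in it (apply_binary, load_as_float, and so on) is illustrative, not the ExecuTorch API.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// One type-erased load/store pair per dtype; the op lambda itself is
// instantiated only once per compute type (float here).
using LoadFn = float (*)(const void*, std::size_t);
using StoreFn = void (*)(void*, std::size_t, float);

template <typename T>
float load_as_float(const void* p, std::size_t i) {
  return static_cast<float>(static_cast<const T*>(p)[i]);
}

template <typename T>
void store_from_float(void* p, std::size_t i, float v) {
  static_cast<T*>(p)[i] = static_cast<T>(v);
}

// Instantiated once per op, not once per dtype combination.
template <typename Op>
void apply_binary(
    Op op,
    const void* a, LoadFn load_a,
    const void* b, LoadFn load_b,
    void* out, StoreFn store, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    store(out, i, op(load_a(a, i), load_b(b, i)));
  }
}

int main() {
  std::vector<std::int32_t> a{2, 3, 4};   // int32 input
  std::vector<float> b{2.0f, 2.0f, 0.5f}; // float input
  std::vector<float> out(3);
  apply_binary(
      [](float x, float y) { return std::pow(x, y); },
      a.data(), load_as_float<std::int32_t>,
      b.data(), load_as_float<float>,
      out.data(), store_from_float<float>, out.size());
  for (float v : out) {
    std::cout << v << ' '; // prints: 4 9 2
  }
  std::cout << '\n';
}

In this shape, supporting an extra dtype adds only one small load/store pair instead of multiplying the number of instantiations of the op lambda, which is consistent with the reductions quoted above.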
1 parent 7606476 commit 405f531

File tree

8 files changed: +248 -332 lines changed

kernels/portable/cpu/op_pow.cpp

Lines changed: 95 additions & 149 deletions
@@ -9,101 +9,61 @@
 #include <cmath>
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-
-namespace {
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = std::pow(a_casted, b_casted);
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
-
 Tensor& pow_Tensor_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
       out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
+  // Resize
   ET_KERNEL_CHECK(
-      ctx, common_type != exec_aten::ScalarType::Bool, InvalidArgument, out);
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_REALHB_TYPES(a_type, ctx, "pow.Tensor_Tensor_out", CTYPE_A, [&]() {
-    ET_SWITCH_REALHB_TYPES(
-        b_type, ctx, "pow.Tensor_Tensor_out", CTYPE_B, [&]() {
-          using CTYPE_IN = typename torch::executor::
-              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-          ET_SWITCH_REALH_TYPES(
-              out_type, ctx, "pow.Tensor_Tensor_out", CTYPE_OUT, [&]() {
-                PowInner<
-                    !std::is_same<CTYPE_IN, bool>::value &&
-                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                    CTYPE_A,
-                    CTYPE_B,
-                    CTYPE_IN,
-                    CTYPE_OUT>::run(a, b, out);
-              });
-        });
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
+  }
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "pow.Tensor_Tensor_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return std::pow(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   return out;
@@ -114,51 +74,43 @@ Tensor& pow_Tensor_Scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
 
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
       ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
+      out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
 
-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
   }
 
-  ET_SWITCH_REALHB_TYPES(a_type, ctx, "pow.Tensor_Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        b_type, ctx, "pow.Tensor_Scalar_out", CTYPE_B, [&]() {
-          ET_SWITCH_REAL_TYPES(
-              common_type, ctx, "pow.Tensor_Scalar_out", CTYPE_IN, [&]() {
-                ET_SWITCH_REALH_TYPES(
-                    out_type, ctx, "pow.Tensor_Scalar_out", CTYPE_OUT, [&]() {
-                      CTYPE_B val_b = 0;
-                      utils::extract_scalar(b, &val_b);
-                      apply_unary_map_fn(
-                          [val_b](const CTYPE_A val_a) {
-                            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                            CTYPE_IN value = std::pow(a_casted, b_casted);
-
-                            return static_cast<CTYPE_OUT>(value);
-                          },
-                          a.const_data_ptr<CTYPE_A>(),
-                          out.mutable_data_ptr<CTYPE_OUT>(),
-                          out.numel());
-                    });
-              });
-        });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "pow.Tensor_Scalar_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   return out;
@@ -169,49 +121,43 @@ Tensor& pow_Scalar_out(
     const Scalar& a,
     const Tensor& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(b.scalar_type(), a);
 
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
       ctx,
-      resize_tensor(out, b.sizes()) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
+      out);
 
-  ScalarType a_type = utils::get_scalar_dtype(a);
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type =
-      utils::promote_type_with_scalar(b_type, a, /*half_to_float*/ false);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(b, out), InvalidArgument, out);
 
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, b.sizes()) == Error::Ok, InvalidArgument, out);
 
-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
   }
 
-  ET_SWITCH_SCALAR_OBJ_TYPES(a_type, ctx, "pow.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_REALHB_TYPES(b_type, ctx, "pow.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES(common_type, ctx, "pow.Scalar_out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALH_TYPES(
-            out_type, ctx, "pow.Scalar_out", CTYPE_OUT, [&]() {
-              CTYPE_A val_a = 0;
-              utils::extract_scalar(a, &val_a);
-
-              apply_unary_map_fn(
-                  [val_a](const CTYPE_B val_b) {
-                    CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                    CTYPE_IN value = std::pow(a_casted, b_casted);
-                    return static_cast<CTYPE_OUT>(value);
-                  },
-                  b.const_data_ptr<CTYPE_B>(),
-                  out.mutable_data_ptr<CTYPE_OUT>(),
-                  out.numel());
-            });
-      });
-    });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "pow.Scalar_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_a = utils::scalar_to<CTYPE_COMPUTE>(a);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); },
+        ctx,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   return out;
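All three pow variants pick a compute type the same way: utils::get_compute_type maps the common dtype to a compute dtype, and the kernel then widens anything that is not already Float to Double so that std::pow runs at full precision even for integral inputs. Below is a standalone C++ sketch of that selection logic, under the assumption that get_compute_type promotes Half and BFloat16 to Float and leaves other dtypes unchanged; the enum and helper are stand-ins modeled on this diff, not the ExecuTorch definitions.

#include <iostream>
#include <string>

// Stand-in dtype enum; the real kernels use executorch's ScalarType.
enum class ScalarType { Bool, Int, Half, BFloat16, Float, Double };

// Assumption: reduced-precision floats are computed in Float, everything
// else keeps its own dtype.
ScalarType get_compute_type(ScalarType common) {
  if (common == ScalarType::Half || common == ScalarType::BFloat16) {
    return ScalarType::Float;
  }
  return common;
}

std::string name(ScalarType t) {
  switch (t) {
    case ScalarType::Float:
      return "Float";
    case ScalarType::Double:
      return "Double";
    default:
      return "other";
  }
}

int main() {
  for (ScalarType common :
       {ScalarType::Int, ScalarType::Half, ScalarType::Float}) {
    ScalarType compute = get_compute_type(common);
    // The pow kernels then force any non-Float compute type to Double:
    if (compute != ScalarType::Float) {
      compute = ScalarType::Double;
    }
    std::cout << name(compute) << '\n'; // prints: Double, Float, Float
  }
}

The rsub kernel below keeps the compute type unchanged; only pow forces the Double widening.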

kernels/portable/cpu/op_rsub.cpp

Lines changed: 32 additions & 43 deletions
@@ -7,8 +7,7 @@
  */
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -21,57 +20,47 @@ Tensor& rsub_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+
+  // Check alpha type
+  ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
 
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
       ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
+      (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
+      out);
 
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
-  ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  // Resize
   ET_KERNEL_CHECK(
-      ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
-  ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out);
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
 
-  ET_SWITCH_REAL_TYPES(a_type, ctx, "rsub.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_REAL_TYPES(
-        b_type, ctx, "rsub.Scalar_out", CTYPE_B, [&]() {
-          ET_SWITCH_REAL_TYPES(
-              common_type, ctx, "rsub.Scalar_out", CTYPE_IN, [&]() {
-                ET_SWITCH_REAL_TYPES(
-                    out_type, ctx, "rsub.Scalar_out", CTYPE_OUT, [&]() {
-                      CTYPE_B b_val;
-                      utils::extract_scalar(b, &b_val);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                      CTYPE_IN alpha_val;
-                      utils::extract_scalar(alpha, &alpha_val);
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "rsub.Scalar_out";
 
-                      apply_unary_map_fn(
-                          [b_casted, alpha_val](const CTYPE_A val_a) {
-                            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                            CTYPE_IN value = b_casted - alpha_val * a_casted;
-                            return static_cast<CTYPE_OUT>(value);
-                          },
-                          a.const_data_ptr<CTYPE_A>(),
-                          out.mutable_data_ptr<CTYPE_OUT>(),
-                          out.numel());
-                    });
-              });
-        });
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
+          return val_b - val_alpha * val_a;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBF16,
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
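The new lambda makes the rsub semantics easy to read off: out = b - alpha * a elementwise, i.e. the reverse of sub with a scaling factor on the tensor operand. A trivial standalone check of that arithmetic:

#include <cassert>

int main() {
  // rsub(a, b, alpha) computes b - alpha * a, matching the lambda above.
  const float a = 3.0f;
  const float b = 10.0f;
  const float alpha = 2.0f;
  assert(b - alpha * a == 4.0f); // 10 - 2*3 == 4
  return 0;
}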

0 commit comments
