Skip to content

Commit 607d4a3

Browse files
manuelcandales authored and
facebook-github-bot committed
Introduce notion of compute type. Refactor add/clamp/where (#6007)
Summary: Pull Request resolved: #6007 Introduced notion of compute type. We now perform the computation over `CTYPE_COMP` (the compute type) rather than `CTYPE_COMMON` (the common type) All of the occurrences of `CTYPE_COMMON` in elementwise_util.h need to be replaced with `CTYPE_COMP`, to properly reflect that we are dealing with the computation type, and not the common type. But we don't do that in this diff, to facilitate review. The previous `SupportedTensorDtypes::SAME_AS_COMMON` is transformed into `SupportedTensorDtypes::SAME_AS_COMP` and a newer `SupportedTensorDtypes::SAME_AS_COMMON` is written. `SupportedTensorDtypes::SAME_AS_COMMON` should perform the reverse mapping of get_compute_type(). In this case, this means that when `CTYPE_COMP` is anything but float, `SAME_AS_COMMON` is effectively the same as `SAME_AS_COMP`. But when `CTYPE_COMP` is float, `SAME_AS_COMMON` switches over `Float`, `Half` and `BFloat16` Build size reduction: - add: 21K -> 18K - clamp: 28K -> 23K - where: 16K -> 12K ghstack-source-id: 246981810 exported-using-ghexport Reviewed By: swolchok Differential Revision: D63860791 fbshipit-source-id: 30763a1f456a701eeb15d2a7b37ae28825579625
1 parent 4153371 commit 607d4a3

File tree

4 files changed

+243
-129
lines changed

4 files changed

+243
-129
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,36 +24,44 @@ Tensor& add_out(
2424
Tensor& out) {
2525
ET_KERNEL_CHECK(
2626
ctx,
27-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
27+
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
28+
executorch::runtime::tensor_is_realhbbf16_type(b) &&
29+
executorch::runtime::tensor_is_realhbbf16_type(out)),
2830
InvalidArgument,
2931
out);
3032

33+
// Common Dtype
34+
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
35+
36+
// Check Common Dtype
3137
ET_KERNEL_CHECK(
3238
ctx,
33-
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
34-
executorch::runtime::tensor_is_realhbbf16_type(b) &&
35-
executorch::runtime::tensor_is_realhbbf16_type(out)),
39+
(canCast(common_type, out.scalar_type()) &&
40+
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
3641
InvalidArgument,
3742
out);
43+
44+
// Check Dim Order
3845
ET_KERNEL_CHECK(
3946
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
4047

41-
ScalarType a_type = a.scalar_type();
42-
ScalarType b_type = b.scalar_type();
43-
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
44-
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
45-
ScalarType out_type = out.scalar_type();
46-
47-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
48+
// Resize
4849
ET_KERNEL_CHECK(
49-
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
50+
ctx,
51+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
52+
InvalidArgument,
53+
out);
54+
55+
// Compute Dtype
56+
ScalarType compute_type = utils::get_compute_type(common_type);
5057

58+
// @lint-ignore CLANGTIDY facebook-hte-CArray
5159
static constexpr const char op_name[] = "add.out";
5260

53-
ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
54-
utils::apply_bitensor_elementwise_fn<CTYPE_COMMON, op_name>(
55-
[alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) {
56-
CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
61+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
62+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
63+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
64+
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
5765
return val_a + val_alpha * val_b;
5866
},
5967
a,
@@ -73,52 +81,49 @@ Tensor& add_scalar_out(
7381
const Scalar& b,
7482
const Scalar& alpha,
7583
Tensor& out) {
76-
(void)ctx;
77-
78-
// Resize for dynamic shape
79-
ET_KERNEL_CHECK_MSG(
84+
ET_KERNEL_CHECK(
8085
ctx,
81-
resize_tensor(out, a.sizes()) == Error::Ok,
86+
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
87+
executorch::runtime::tensor_is_realhbbf16_type(out)),
8288
InvalidArgument,
83-
out,
84-
"Failed to resize output tensor.");
89+
out);
8590

91+
// Common Dtype
92+
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
93+
94+
// Check Common Dtype
8695
ET_KERNEL_CHECK(
8796
ctx,
88-
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
89-
executorch::runtime::tensor_is_realhbbf16_type(out)),
97+
(common_type == out.scalar_type() &&
98+
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
9099
InvalidArgument,
91100
out);
101+
102+
// Check Dim Order
92103
ET_KERNEL_CHECK(
93104
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
94105

95-
ScalarType a_type = a.scalar_type();
96-
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
97-
ScalarType common_type =
98-
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
99-
ScalarType out_type = out.scalar_type();
100-
101-
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
106+
// Resize
102107
ET_KERNEL_CHECK(
103-
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
108+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
104109

105-
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
106-
common_type = ScalarType::Float;
107-
}
110+
// Compute Dtype
111+
ScalarType compute_type = utils::get_compute_type(common_type);
108112

113+
// @lint-ignore CLANGTIDY facebook-hte-CArray
109114
static constexpr const char op_name[] = "add.Scalar_out";
110115

111-
ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
112-
utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
113-
[b, alpha](const CTYPE_COMMON val_a) {
114-
CTYPE_COMMON val_b = utils::scalar_to<CTYPE_COMMON>(b);
115-
CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
116+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
117+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
118+
[b, alpha](const CTYPE_COMPUTE val_a) {
119+
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
120+
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
116121
return val_a + val_alpha * val_b;
117122
},
118123
a,
119124
utils::SupportedTensorDtypes::REALHBBF16,
120125
out,
121-
utils::SupportedTensorDtypes::REALHBBF16);
126+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
122127
});
123128

124129
return out;

kernels/portable/cpu/op_clamp.cpp

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -73,74 +73,91 @@ Tensor& clamp_out(
7373
const exec_aten::optional<Scalar>& min_opt,
7474
const exec_aten::optional<Scalar>& max_opt,
7575
Tensor& out) {
76-
(void)ctx;
76+
ET_KERNEL_CHECK(
77+
ctx,
78+
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
79+
executorch::runtime::tensor_is_realhbbf16_type(out)),
80+
InvalidArgument,
81+
out);
82+
83+
bool has_min = min_opt.has_value();
84+
bool has_max = max_opt.has_value();
7785

7886
ET_KERNEL_CHECK_MSG(
7987
ctx,
80-
resize_tensor(out, in.sizes()) == Error::Ok,
88+
has_min || has_max,
8189
InvalidArgument,
8290
out,
83-
"Failed to resize output tensor.");
84-
85-
ET_KERNEL_CHECK(
86-
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
91+
"At least one of 'min' or 'max' must not be None");
8792

93+
// Input Dtypes
8894
ScalarType in_type = in.scalar_type();
89-
ScalarType min_type = in_type;
90-
ScalarType max_type = in_type;
91-
ScalarType common_type = in_type;
95+
ScalarType min_type =
96+
has_min ? utils::get_scalar_dtype(min_opt.value()) : in_type;
97+
ScalarType max_type =
98+
has_max ? utils::get_scalar_dtype(max_opt.value()) : in_type;
9299
ScalarType out_type = out.scalar_type();
93100

94-
bool has_min = min_opt.has_value();
101+
// Common Dtype
102+
ScalarType common_type = in_type;
95103
if (has_min) {
96-
min_type = utils::get_scalar_dtype(min_opt.value());
97104
common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
105+
}
106+
if (has_max) {
107+
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
108+
}
109+
110+
// Check Common Dtype
111+
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
112+
113+
// Check Scalar Bounds
114+
if (has_min) {
98115
ET_KERNEL_CHECK(
99116
ctx,
100117
check_bounds(min_opt.value(), min_type, out_type, "minimum"),
101118
InvalidArgument,
102119
out);
103120
}
104-
bool has_max = max_opt.has_value();
105121
if (has_max) {
106-
max_type = utils::get_scalar_dtype(max_opt.value());
107-
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
108122
ET_KERNEL_CHECK(
109123
ctx,
110124
check_bounds(max_opt.value(), max_type, out_type, "maximum"),
111125
InvalidArgument,
112126
out);
113127
}
114128

115-
ET_KERNEL_CHECK_MSG(
116-
ctx,
117-
has_min || has_max,
118-
InvalidArgument,
119-
out,
120-
"At least one of 'min' or 'max' must not be None");
129+
// Check Dim Order
130+
ET_KERNEL_CHECK(
131+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
121132

122-
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
133+
// Resize
134+
ET_KERNEL_CHECK(
135+
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
136+
137+
// Compute Dtype
138+
ScalarType compute_type = utils::get_compute_type(common_type);
123139

140+
// @lint-ignore CLANGTIDY facebook-hte-CArray
124141
static constexpr const char op_name[] = "clamp.out";
125142

126-
ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
127-
utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
128-
[has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) {
129-
CTYPE_COMMON val_out = val_in;
143+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
144+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
145+
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
146+
CTYPE_COMPUTE val_out = val_in;
130147
if (has_min) {
131148
val_out = utils::max_override(
132-
val_out, utils::scalar_to<CTYPE_COMMON>(min_opt.value()));
149+
val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
133150
}
134151
if (has_max) {
135152
val_out = utils::min_override(
136-
val_out, utils::scalar_to<CTYPE_COMMON>(max_opt.value()));
153+
val_out, utils::scalar_to<CTYPE_COMPUTE>(max_opt.value()));
137154
}
138155
return val_out;
139156
},
140157
in,
141158
utils::SupportedTensorDtypes::REALHBBF16,
142159
out,
143-
utils::SupportedTensorDtypes::REALHBBF16);
160+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
144161
});
145162

146163
return out;
@@ -152,8 +169,6 @@ Tensor& clamp_tensor_out(
152169
const exec_aten::optional<Tensor>& min_opt,
153170
const exec_aten::optional<Tensor>& max_opt,
154171
Tensor& out) {
155-
(void)ctx;
156-
157172
bool has_min = min_opt.has_value();
158173
bool has_max = max_opt.has_value();
159174

@@ -167,42 +182,55 @@ Tensor& clamp_tensor_out(
167182
const Tensor& min = has_min ? min_opt.value() : in;
168183
const Tensor& max = has_max ? max_opt.value() : in;
169184

185+
ET_KERNEL_CHECK(
186+
ctx,
187+
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
188+
executorch::runtime::tensor_is_realhbbf16_type(min) &&
189+
executorch::runtime::tensor_is_realhbbf16_type(max) &&
190+
executorch::runtime::tensor_is_realhbbf16_type(out)),
191+
InvalidArgument,
192+
out);
193+
194+
// Common Dtype
195+
ScalarType common_type = in.scalar_type();
196+
if (has_min) {
197+
common_type = promoteTypes(common_type, min.scalar_type());
198+
}
199+
if (has_max) {
200+
common_type = promoteTypes(common_type, max.scalar_type());
201+
}
202+
203+
// Check Common Dtype
204+
ET_KERNEL_CHECK(
205+
ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
206+
207+
// Check Dim Order
170208
ET_KERNEL_CHECK(
171209
ctx,
172210
tensors_have_same_dim_order(in, min, max, out),
173211
InvalidArgument,
174212
out);
175213

214+
// Resize
176215
ET_KERNEL_CHECK(
177216
ctx,
178217
resize_to_broadcast_target_size(in, min, max, out) == Error::Ok,
179218
InvalidArgument,
180219
out);
181220

182-
ScalarType in_type = in.scalar_type();
183-
ScalarType min_type = min.scalar_type();
184-
ScalarType max_type = max.scalar_type();
185-
ScalarType common_type = in_type;
186-
ScalarType out_type = out.scalar_type();
187-
188-
if (has_min) {
189-
common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true);
190-
}
191-
if (has_max) {
192-
common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true);
193-
}
194-
195-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
221+
// Compute Dtype
222+
ScalarType compute_type = utils::get_compute_type(common_type);
196223

224+
// @lint-ignore CLANGTIDY facebook-hte-CArray
197225
static constexpr const char op_name[] = "clamp.Tensor_out";
198226

199-
ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
200-
utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
227+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
228+
utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
201229
[has_min, has_max](
202-
const CTYPE_COMMON val_in,
203-
const CTYPE_COMMON val_min,
204-
const CTYPE_COMMON val_max) {
205-
CTYPE_COMMON val_out = val_in;
230+
const CTYPE_COMPUTE val_in,
231+
const CTYPE_COMPUTE val_min,
232+
const CTYPE_COMPUTE val_max) {
233+
CTYPE_COMPUTE val_out = val_in;
206234
if (has_min) {
207235
val_out = utils::max_override(val_out, val_min);
208236
}

0 commit comments

Comments
 (0)