Skip to content

Commit 911ad75

Browse files
[ET][Portable][Build Size] Introduce notion of compute type. Refactor add/clamp/where
Pull Request resolved: #6007 Introduced notion of compute type. We now perform the computation over `CTYPE_COMP` (the compute type) rather than `CTYPE_COMMON` (the common type). All of the occurrences of `CTYPE_COMMON` in elementwise_util.h need to be replaced with `CTYPE_COMP`, to properly reflect that we are dealing with the computation type, and not the common type. But we don't do that in this diff, to facilitate review. The previous `SupportedTensorDtypes::SAME_AS_COMMON` is transformed into `SupportedTensorDtypes::SAME_AS_COMP`, and a newer `SupportedTensorDtypes::SAME_AS_COMMON` is written. `SupportedTensorDtypes::SAME_AS_COMMON` should perform the reverse of the mapping done by get_compute_type(). In this case, this means that when `CTYPE_COMP` is anything but float, `SAME_AS_COMMON` is effectively the same as `SAME_AS_COMP`. But when `CTYPE_COMP` is float, `SAME_AS_COMMON` switches over `Float`, `Half` and `BFloat16`. Build size reduction: - add: 21K -> 18K - clamp: 28K -> 23K - where: 16K -> 12K ghstack-source-id: 246919711 @exported-using-ghexport Differential Revision: [D63860791](https://our.internmc.facebook.com/intern/diff/D63860791/)
1 parent 2cda076 commit 911ad75

File tree

4 files changed

+238
-129
lines changed

4 files changed

+238
-129
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,36 +24,43 @@ Tensor& add_out(
2424
Tensor& out) {
2525
ET_KERNEL_CHECK(
2626
ctx,
27-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
27+
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
28+
executorch::runtime::tensor_is_realhbbf16_type(b) &&
29+
executorch::runtime::tensor_is_realhbbf16_type(out)),
2830
InvalidArgument,
2931
out);
3032

33+
// Common Dtype
34+
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
35+
36+
// Check Common Dtype
3137
ET_KERNEL_CHECK(
3238
ctx,
33-
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
34-
executorch::runtime::tensor_is_realhbbf16_type(b) &&
35-
executorch::runtime::tensor_is_realhbbf16_type(out)),
39+
(canCast(common_type, out.scalar_type()) &&
40+
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
3641
InvalidArgument,
3742
out);
43+
44+
// Check Dim Order
3845
ET_KERNEL_CHECK(
3946
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
4047

41-
ScalarType a_type = a.scalar_type();
42-
ScalarType b_type = b.scalar_type();
43-
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
44-
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
45-
ScalarType out_type = out.scalar_type();
46-
47-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
48+
// Resize
4849
ET_KERNEL_CHECK(
49-
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
50+
ctx,
51+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
52+
InvalidArgument,
53+
out);
54+
55+
// Compute Dtype
56+
ScalarType compute_type = utils::get_compute_type(common_type);
5057

5158
static constexpr const char op_name[] = "add.out";
5259

53-
ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
54-
utils::apply_bitensor_elementwise_fn<CTYPE_COMMON, op_name>(
55-
[alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) {
56-
CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
60+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
61+
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
62+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
63+
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
5764
return val_a + val_alpha * val_b;
5865
},
5966
a,
@@ -73,52 +80,48 @@ Tensor& add_scalar_out(
7380
const Scalar& b,
7481
const Scalar& alpha,
7582
Tensor& out) {
76-
(void)ctx;
77-
78-
// Resize for dynamic shape
79-
ET_KERNEL_CHECK_MSG(
83+
ET_KERNEL_CHECK(
8084
ctx,
81-
resize_tensor(out, a.sizes()) == Error::Ok,
85+
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
86+
executorch::runtime::tensor_is_realhbbf16_type(out)),
8287
InvalidArgument,
83-
out,
84-
"Failed to resize output tensor.");
88+
out);
8589

90+
// Common Dtype
91+
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
92+
93+
// Check Common Dtype
8694
ET_KERNEL_CHECK(
8795
ctx,
88-
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
89-
executorch::runtime::tensor_is_realhbbf16_type(out)),
96+
(common_type == out.scalar_type() &&
97+
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
9098
InvalidArgument,
9199
out);
100+
101+
// Check Dim Order
92102
ET_KERNEL_CHECK(
93103
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
94104

95-
ScalarType a_type = a.scalar_type();
96-
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
97-
ScalarType common_type =
98-
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
99-
ScalarType out_type = out.scalar_type();
100-
101-
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
105+
// Resize
102106
ET_KERNEL_CHECK(
103-
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
107+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
104108

105-
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
106-
common_type = ScalarType::Float;
107-
}
109+
// Compute Dtype
110+
ScalarType compute_type = utils::get_compute_type(common_type);
108111

109112
static constexpr const char op_name[] = "add.Scalar_out";
110113

111-
ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
112-
utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
113-
[b, alpha](const CTYPE_COMMON val_a) {
114-
CTYPE_COMMON val_b = utils::scalar_to<CTYPE_COMMON>(b);
115-
CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
114+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
115+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
116+
[b, alpha](const CTYPE_COMPUTE val_a) {
117+
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
118+
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
116119
return val_a + val_alpha * val_b;
117120
},
118121
a,
119122
utils::SupportedTensorDtypes::REALHBBF16,
120123
out,
121-
utils::SupportedTensorDtypes::REALHBBF16);
124+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
122125
});
123126

124127
return out;

kernels/portable/cpu/op_clamp.cpp

Lines changed: 76 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -73,74 +73,90 @@ Tensor& clamp_out(
7373
const exec_aten::optional<Scalar>& min_opt,
7474
const exec_aten::optional<Scalar>& max_opt,
7575
Tensor& out) {
76-
(void)ctx;
76+
ET_KERNEL_CHECK(
77+
ctx,
78+
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
79+
executorch::runtime::tensor_is_realhbbf16_type(out)),
80+
InvalidArgument,
81+
out);
82+
83+
bool has_min = min_opt.has_value();
84+
bool has_max = max_opt.has_value();
7785

7886
ET_KERNEL_CHECK_MSG(
7987
ctx,
80-
resize_tensor(out, in.sizes()) == Error::Ok,
88+
has_min || has_max,
8189
InvalidArgument,
8290
out,
83-
"Failed to resize output tensor.");
84-
85-
ET_KERNEL_CHECK(
86-
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
91+
"At least one of 'min' or 'max' must not be None");
8792

93+
// Input Dtypes
8894
ScalarType in_type = in.scalar_type();
89-
ScalarType min_type = in_type;
90-
ScalarType max_type = in_type;
91-
ScalarType common_type = in_type;
95+
ScalarType min_type =
96+
has_min ? utils::get_scalar_dtype(min_opt.value()) : in_type;
97+
ScalarType max_type =
98+
has_max ? utils::get_scalar_dtype(max_opt.value()) : in_type;
9299
ScalarType out_type = out.scalar_type();
93100

94-
bool has_min = min_opt.has_value();
101+
// Common Dtype
102+
ScalarType common_type = in_type;
95103
if (has_min) {
96-
min_type = utils::get_scalar_dtype(min_opt.value());
97104
common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
105+
}
106+
if (has_max) {
107+
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
108+
}
109+
110+
// Check Common Dtype
111+
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
112+
113+
// Check Scalar Bounds
114+
if (has_min) {
98115
ET_KERNEL_CHECK(
99116
ctx,
100117
check_bounds(min_opt.value(), min_type, out_type, "minimum"),
101118
InvalidArgument,
102119
out);
103120
}
104-
bool has_max = max_opt.has_value();
105121
if (has_max) {
106-
max_type = utils::get_scalar_dtype(max_opt.value());
107-
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
108122
ET_KERNEL_CHECK(
109123
ctx,
110124
check_bounds(max_opt.value(), max_type, out_type, "maximum"),
111125
InvalidArgument,
112126
out);
113127
}
114128

115-
ET_KERNEL_CHECK_MSG(
116-
ctx,
117-
has_min || has_max,
118-
InvalidArgument,
119-
out,
120-
"At least one of 'min' or 'max' must not be None");
129+
// Check Dim Order
130+
ET_KERNEL_CHECK(
131+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
121132

122-
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
133+
// Resize
134+
ET_KERNEL_CHECK(
135+
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
136+
137+
// Compute Dtype
138+
ScalarType compute_type = utils::get_compute_type(common_type);
123139

124140
static constexpr const char op_name[] = "clamp.out";
125141

126-
ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
127-
utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
128-
[has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) {
129-
CTYPE_COMMON val_out = val_in;
142+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
143+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
144+
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
145+
CTYPE_COMPUTE val_out = val_in;
130146
if (has_min) {
131147
val_out = utils::max_override(
132-
val_out, utils::scalar_to<CTYPE_COMMON>(min_opt.value()));
148+
val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
133149
}
134150
if (has_max) {
135151
val_out = utils::min_override(
136-
val_out, utils::scalar_to<CTYPE_COMMON>(max_opt.value()));
152+
val_out, utils::scalar_to<CTYPE_COMPUTE>(max_opt.value()));
137153
}
138154
return val_out;
139155
},
140156
in,
141157
utils::SupportedTensorDtypes::REALHBBF16,
142158
out,
143-
utils::SupportedTensorDtypes::REALHBBF16);
159+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
144160
});
145161

146162
return out;
@@ -152,8 +168,6 @@ Tensor& clamp_tensor_out(
152168
const exec_aten::optional<Tensor>& min_opt,
153169
const exec_aten::optional<Tensor>& max_opt,
154170
Tensor& out) {
155-
(void)ctx;
156-
157171
bool has_min = min_opt.has_value();
158172
bool has_max = max_opt.has_value();
159173

@@ -167,42 +181,54 @@ Tensor& clamp_tensor_out(
167181
const Tensor& min = has_min ? min_opt.value() : in;
168182
const Tensor& max = has_max ? max_opt.value() : in;
169183

184+
ET_KERNEL_CHECK(
185+
ctx,
186+
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
187+
executorch::runtime::tensor_is_realhbbf16_type(min) &&
188+
executorch::runtime::tensor_is_realhbbf16_type(max) &&
189+
executorch::runtime::tensor_is_realhbbf16_type(out)),
190+
InvalidArgument,
191+
out);
192+
193+
// Common Dtype
194+
ScalarType common_type = in.scalar_type();
195+
if (has_min) {
196+
common_type = promoteTypes(common_type, min.scalar_type());
197+
}
198+
if (has_max) {
199+
common_type = promoteTypes(common_type, max.scalar_type());
200+
}
201+
202+
// Check Common Dtype
203+
ET_KERNEL_CHECK(
204+
ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
205+
206+
// Check Dim Order
170207
ET_KERNEL_CHECK(
171208
ctx,
172209
tensors_have_same_dim_order(in, min, max, out),
173210
InvalidArgument,
174211
out);
175212

213+
// Resize
176214
ET_KERNEL_CHECK(
177215
ctx,
178216
resize_to_broadcast_target_size(in, min, max, out) == Error::Ok,
179217
InvalidArgument,
180218
out);
181219

182-
ScalarType in_type = in.scalar_type();
183-
ScalarType min_type = min.scalar_type();
184-
ScalarType max_type = max.scalar_type();
185-
ScalarType common_type = in_type;
186-
ScalarType out_type = out.scalar_type();
187-
188-
if (has_min) {
189-
common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true);
190-
}
191-
if (has_max) {
192-
common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true);
193-
}
194-
195-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
220+
// Compute Dtype
221+
ScalarType compute_type = utils::get_compute_type(common_type);
196222

197223
static constexpr const char op_name[] = "clamp.Tensor_out";
198224

199-
ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
200-
utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
225+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
226+
utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
201227
[has_min, has_max](
202-
const CTYPE_COMMON val_in,
203-
const CTYPE_COMMON val_min,
204-
const CTYPE_COMMON val_max) {
205-
CTYPE_COMMON val_out = val_in;
228+
const CTYPE_COMPUTE val_in,
229+
const CTYPE_COMPUTE val_min,
230+
const CTYPE_COMPUTE val_max) {
231+
CTYPE_COMPUTE val_out = val_in;
206232
if (has_min) {
207233
val_out = utils::max_override(val_out, val_min);
208234
}

0 commit comments

Comments
 (0)