Skip to content

Commit f8d182b

Browse files
manuelcandales authored and facebook-github-bot committed
Check tensor dtype inside elementwise_utils (#6008)
Summary: Pull Request resolved: #6008 ghstack-source-id: 246985130 exported-using-ghexport Reviewed By: swolchok Differential Revision: D64052137 fbshipit-source-id: 7a72b3efa2bcdbd8412a03dd0cbb961ea56b6d0d
1 parent 607d4a3 commit f8d182b

File tree

7 files changed

+103
-44
lines changed

7 files changed

+103
-44
lines changed

kernels/portable/cpu/op_add.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,6 @@ Tensor& add_out(
2222
const Tensor& b,
2323
const Scalar& alpha,
2424
Tensor& out) {
25-
ET_KERNEL_CHECK(
26-
ctx,
27-
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
28-
executorch::runtime::tensor_is_realhbbf16_type(b) &&
29-
executorch::runtime::tensor_is_realhbbf16_type(out)),
30-
InvalidArgument,
31-
out);
32-
3325
// Common Dtype
3426
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
3527

@@ -64,6 +56,7 @@ Tensor& add_out(
6456
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
6557
return val_a + val_alpha * val_b;
6658
},
59+
ctx,
6760
a,
6861
utils::SupportedTensorDtypes::REALHBBF16,
6962
b,
@@ -81,13 +74,6 @@ Tensor& add_scalar_out(
8174
const Scalar& b,
8275
const Scalar& alpha,
8376
Tensor& out) {
84-
ET_KERNEL_CHECK(
85-
ctx,
86-
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
87-
executorch::runtime::tensor_is_realhbbf16_type(out)),
88-
InvalidArgument,
89-
out);
90-
9177
// Common Dtype
9278
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
9379

@@ -120,6 +106,7 @@ Tensor& add_scalar_out(
120106
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
121107
return val_a + val_alpha * val_b;
122108
},
109+
ctx,
123110
a,
124111
utils::SupportedTensorDtypes::REALHBBF16,
125112
out,

kernels/portable/cpu/op_clamp.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,6 @@ Tensor& clamp_out(
7373
const exec_aten::optional<Scalar>& min_opt,
7474
const exec_aten::optional<Scalar>& max_opt,
7575
Tensor& out) {
76-
ET_KERNEL_CHECK(
77-
ctx,
78-
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
79-
executorch::runtime::tensor_is_realhbbf16_type(out)),
80-
InvalidArgument,
81-
out);
82-
8376
bool has_min = min_opt.has_value();
8477
bool has_max = max_opt.has_value();
8578

@@ -154,6 +147,7 @@ Tensor& clamp_out(
154147
}
155148
return val_out;
156149
},
150+
ctx,
157151
in,
158152
utils::SupportedTensorDtypes::REALHBBF16,
159153
out,
@@ -182,15 +176,6 @@ Tensor& clamp_tensor_out(
182176
const Tensor& min = has_min ? min_opt.value() : in;
183177
const Tensor& max = has_max ? max_opt.value() : in;
184178

185-
ET_KERNEL_CHECK(
186-
ctx,
187-
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
188-
executorch::runtime::tensor_is_realhbbf16_type(min) &&
189-
executorch::runtime::tensor_is_realhbbf16_type(max) &&
190-
executorch::runtime::tensor_is_realhbbf16_type(out)),
191-
InvalidArgument,
192-
out);
193-
194179
// Common Dtype
195180
ScalarType common_type = in.scalar_type();
196181
if (has_min) {
@@ -239,6 +224,7 @@ Tensor& clamp_tensor_out(
239224
}
240225
return val_out;
241226
},
227+
ctx,
242228
in,
243229
utils::SupportedTensorDtypes::REALHBBF16,
244230
min,

kernels/portable/cpu/op_where.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,6 @@ Tensor& where_out(
1919
const Tensor& a,
2020
const Tensor& b,
2121
Tensor& out) {
22-
ET_KERNEL_CHECK(
23-
ctx,
24-
((cond.scalar_type() == ScalarType::Bool ||
25-
cond.scalar_type() == ScalarType::Byte) &&
26-
executorch::runtime::tensor_is_realhbbf16_type(a) &&
27-
executorch::runtime::tensor_is_realhbbf16_type(b) &&
28-
executorch::runtime::tensor_is_realhbbf16_type(out)),
29-
InvalidArgument,
30-
out);
31-
3222
// Common Dtype
3323
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
3424

@@ -57,6 +47,7 @@ Tensor& where_out(
5747
[](const CTYPE_COMPUTE val_a,
5848
const CTYPE_COMPUTE val_b,
5949
const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
50+
ctx,
6051
a,
6152
utils::SupportedTensorDtypes::REALHBBF16,
6253
b,
kernels/portable/cpu/util/elementwise_util.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace native {
14+
namespace utils {
15+
namespace internal {
16+
17+
/**
 * Reports whether the dtype of tensor `t` satisfies the constraint named by
 * `dtypes`, given the scalar type `compute_type` used for the computation.
 *
 * Constraint handling, as implemented by the switch below:
 *  - REALHBBF16: delegates to tensor_is_realhbbf16_type(t).
 *  - BOOL_OR_BYTE: `t` must be of type Bool or Byte.
 *  - SAME_AS_COMPUTE: `t` must be exactly of type `compute_type`.
 *  - SAME_AS_COMMON: `t` must be of type `compute_type`, except that when
 *    `compute_type` is Float, Half and BFloat16 tensors are accepted too.
 *
 * @param t The tensor whose dtype is inspected.
 * @param dtypes The set of dtypes the caller supports for `t`.
 * @param compute_type The scalar type the elementwise computation runs in.
 * @return true if the dtype of `t` is acceptable, false otherwise.
 */
bool check_tensor_dtype(
18+
const Tensor t,
19+
SupportedTensorDtypes dtypes,
20+
const ScalarType compute_type) {
21+
switch (dtypes) {
22+
case SupportedTensorDtypes::REALHBBF16:
23+
return executorch::runtime::tensor_is_realhbbf16_type(t);
24+
case SupportedTensorDtypes::BOOL_OR_BYTE:
25+
return (
26+
executorch::runtime::tensor_is_type(t, ScalarType::Bool) ||
27+
executorch::runtime::tensor_is_type(t, ScalarType::Byte));
28+
case SupportedTensorDtypes::SAME_AS_COMPUTE:
29+
return executorch::runtime::tensor_is_type(t, compute_type);
30+
case SupportedTensorDtypes::SAME_AS_COMMON: {
31+
// A Float compute type also admits the reduced-precision floating types.
if (compute_type == ScalarType::Float) {
32+
return (
33+
executorch::runtime::tensor_is_type(t, ScalarType::Float) ||
34+
executorch::runtime::tensor_is_type(t, ScalarType::Half) ||
35+
executorch::runtime::tensor_is_type(t, ScalarType::BFloat16));
36+
} else {
37+
return executorch::runtime::tensor_is_type(t, compute_type);
38+
}
39+
}
40+
}
41+
// Only reached when `dtypes` matched none of the cases above.
// NOTE(review): ET_CHECK(false) presumably aborts; the trailing return then
// only gives the function a return value on this path -- confirm.
ET_CHECK(false);
42+
return false;
43+
}
44+
45+
} // namespace internal
46+
} // namespace utils
47+
} // namespace native
48+
} // namespace executor
49+
} // namespace torch

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,29 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
229229
return nullptr;
230230
}
231231

232+
bool check_tensor_dtype(
233+
const Tensor t,
234+
SupportedTensorDtypes dtypes,
235+
const ScalarType compute_type);
236+
232237
} // namespace internal
233238

234239
template <typename CTYPE_COMMON, const char* op_name, typename Op>
235240
inline void apply_unitensor_elementwise_fn(
236241
const Op& compute_fun,
242+
KernelRuntimeContext& ctx,
237243
const Tensor& a,
238244
SupportedTensorDtypes a_dtypes,
239245
const Tensor& out,
240246
SupportedTensorDtypes out_dtypes) {
247+
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
248+
249+
ET_KERNEL_CHECK(
250+
ctx,
251+
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
252+
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
253+
InvalidArgument, );
254+
241255
const auto load_a_to_common =
242256
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
243257
const auto store_common_to_out =
@@ -263,12 +277,22 @@ inline void apply_unitensor_elementwise_fn(
263277
template <typename CTYPE_COMMON, const char* op_name, typename Op>
264278
inline void apply_bitensor_elementwise_fn(
265279
const Op& compute_fun,
280+
KernelRuntimeContext& ctx,
266281
const Tensor& a,
267282
SupportedTensorDtypes a_dtypes,
268283
const Tensor& b,
269284
SupportedTensorDtypes b_dtypes,
270285
const Tensor& out,
271286
SupportedTensorDtypes out_dtypes) {
287+
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
288+
289+
ET_KERNEL_CHECK(
290+
ctx,
291+
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
292+
internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
293+
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
294+
InvalidArgument, );
295+
272296
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
273297
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
274298
const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
@@ -312,9 +336,9 @@ inline void apply_bitensor_elementwise_fn(
312336
}
313337

314338
/**
315-
* Useful for tri-tensor elementwise operators. For each element of the inputs,
316-
* perform a computation and write to the corresponding element of the output.
317-
* Tensor broadcasting is applied wherever it is required.
339+
* Useful for tri-tensor elementwise operators. For each element of the
340+
* inputs, perform a computation and write to the corresponding element of the
341+
* output. Tensor broadcasting is applied wherever it is required.
318342
*
319343
* In order to mitigate build time cost (straightforwardly |CTYPE_A| *
320344
* |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
@@ -334,6 +358,7 @@ inline void apply_bitensor_elementwise_fn(
334358
template <typename CTYPE_COMMON, const char* op_name, typename Op>
335359
inline void apply_tritensor_elementwise_fn(
336360
const Op& compute_fun,
361+
KernelRuntimeContext& ctx,
337362
const Tensor& a,
338363
SupportedTensorDtypes a_dtypes,
339364
const Tensor& b,
@@ -342,6 +367,16 @@ inline void apply_tritensor_elementwise_fn(
342367
SupportedTensorDtypes c_dtypes,
343368
const Tensor& out,
344369
SupportedTensorDtypes out_dtypes) {
370+
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
371+
372+
ET_KERNEL_CHECK(
373+
ctx,
374+
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
375+
internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
376+
internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
377+
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
378+
InvalidArgument, );
379+
345380
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
346381
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
347382
const bool c_is_broadcasted = !out.sizes().equals(c.sizes());

kernels/portable/cpu/util/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def define_common_targets():
8080

8181
runtime.cxx_library(
8282
name = "elementwise_util",
83+
srcs = ["elementwise_util.cpp"],
8384
exported_headers = [
8485
"elementwise_util.h",
8586
],

runtime/core/exec_aten/util/tensor_util.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,16 @@ inline bool tensor_is_bool_type(exec_aten::Tensor t) {
469469
return true;
470470
}
471471

472+
/**
 * Checks that tensor `t` has exactly the scalar type `dtype`.
 *
 * @param t The tensor to inspect.
 * @param dtype The scalar type `t` is expected to have.
 * @return true when `t.scalar_type() == dtype`.
 *
 * NOTE(review): on a mismatch, ET_LOG_MSG_AND_RETURN_IF_FALSE presumably logs
 * the formatted message below and returns false -- inferred from the macro
 * name; its definition is not visible here.
 */
inline bool tensor_is_type(exec_aten::Tensor t, exec_aten::ScalarType dtype) {
473+
ET_LOG_MSG_AND_RETURN_IF_FALSE(
474+
t.scalar_type() == dtype,
475+
"Expected to find %s type, but tensor has type %s",
476+
torch::executor::toString(dtype),
477+
torch::executor::toString(t.scalar_type()));
478+
479+
return true;
480+
}
481+
472482
inline bool tensor_is_integral_type(
473483
exec_aten::Tensor t,
474484
bool includeBool = false) {

0 commit comments

Comments
 (0)