
Commit 713d375

Update on "[ET][Portable][Build Size] Move to dtype_utils"
Differential Revision: [D63994875](https://our.internmc.facebook.com/intern/diff/D63994875/) [ghstack-poisoned]
2 parents 00b7f3d + 42d1284 commit 713d375

6 files changed: +18, -17 lines

kernels/portable/cpu/op_div.cpp

Lines changed: 3 additions & 4 deletions
@@ -133,7 +133,8 @@ Tensor& div_out_mode(
         if (mode_is_trunc) {
           value = std::trunc(value);
         } else {
-          // We established above that the mode is either trunc or floor, so it must be floor.
+          // We established above that the mode is either trunc or floor, so
+          // it must be floor.
           value = utils::floor_divide(val_a, val_b);
         }
         return value;
@@ -185,9 +186,7 @@ Tensor& div_scalar_out(
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b](const CTYPE_COMPUTE val_a) {
-          return val_a / val_b;
-        },
+        [val_b](const CTYPE_COMPUTE val_a) {return val_a / val_b;},
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
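The reworded comment in div_out_mode turns on the difference between the two rounding modes. As a minimal standalone sketch (plain C++, with std::floor standing in for the kernel's utils::floor_divide helper), the two modes only disagree when the quotient is negative:

// Minimal sketch, not ExecuTorch code: "trunc" rounds toward zero, while
// "floor" (the behavior utils::floor_divide provides) rounds toward -infinity.
#include <cmath>
#include <cstdio>

int main() {
  const double val_a = -7.0;
  const double val_b = 2.0;
  const double value = val_a / val_b;            // -3.5
  std::printf("trunc: %g\n", std::trunc(value)); // prints -3
  std::printf("floor: %g\n", std::floor(value)); // prints -4
  return 0;
}

For non-negative quotients the two modes produce the same result; they diverge exactly in cases like the one above.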

kernels/portable/cpu/op_mul.cpp

Lines changed: 7 additions & 3 deletions
@@ -43,6 +43,12 @@ Tensor& mul_out(
 
   static constexpr const char op_name[] = "mul.out";
 
+  ET_KERNEL_CHECK(
+      ctx,
+      (executorch::runtime::isRealType(compute_type) || compute_type == ScalarType::Bool),
+      InvalidArgument,
+      out);
+
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
@@ -87,9 +93,7 @@ Tensor& mul_scalar_out(
   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b](const CTYPE_COMPUTE val_a) {
-          return val_a * val_b;
-        },
+        [val_b](const CTYPE_COMPUTE val_a) {return val_a * val_b;},
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
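The new ET_KERNEL_CHECK in mul_out rejects any compute type outside the set that ET_SWITCH_REALB_TYPES dispatches over (the real types plus Bool) before entering the switch. Below is a toy, self-contained sketch of that guard condition only; is_real_type() is a hypothetical stand-in for executorch::runtime::isRealType(), and the enum is a stripped-down imitation of ScalarType, not the real one:

// Toy illustration only; the real kernel wraps this condition in
// ET_KERNEL_CHECK with a runtime context and returns `out` on failure.
#include <cstdio>

enum class ScalarType { Byte, Char, Short, Int, Long, Float, Double, Bool, QInt8 };

// Hypothetical stand-in for executorch::runtime::isRealType().
bool is_real_type(ScalarType t) {
  switch (t) {
    case ScalarType::Byte:
    case ScalarType::Char:
    case ScalarType::Short:
    case ScalarType::Int:
    case ScalarType::Long:
    case ScalarType::Float:
    case ScalarType::Double:
      return true;
    default:
      return false;
  }
}

// Mirrors the added check: real types and Bool pass, anything else
// (e.g. quantized types such as QInt8) is reported as InvalidArgument.
bool passes_realb_check(ScalarType compute_type) {
  return is_real_type(compute_type) || compute_type == ScalarType::Bool;
}

int main() {
  std::printf("Float: %d\n", passes_realb_check(ScalarType::Float)); // 1
  std::printf("Bool:  %d\n", passes_realb_check(ScalarType::Bool));  // 1
  std::printf("QInt8: %d\n", passes_realb_check(ScalarType::QInt8)); // 0
  return 0;
}

A quantized type such as QInt8 now fails this check up front, which is presumably the path the UnhandledDtypeDies test further down exercises.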

kernels/portable/cpu/op_pow.cpp

Lines changed: 2 additions & 6 deletions
@@ -103,9 +103,7 @@ Tensor& pow_Tensor_Scalar_out(
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b](const CTYPE_COMPUTE val_a) {
-          return std::pow(val_a, val_b);
-        },
+        [val_b](const CTYPE_COMPUTE val_a) {return std::pow(val_a, val_b);},
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
@@ -151,9 +149,7 @@ Tensor& pow_Scalar_out(
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_a = utils::scalar_to<CTYPE_COMPUTE>(a);
     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_a](const CTYPE_COMPUTE val_b) {
-          return std::pow(val_a, val_b);
-        },
+        [val_a](const CTYPE_COMPUTE val_b) {return std::pow(val_a, val_b);},
         ctx,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,

kernels/portable/cpu/op_rsub.cpp

Lines changed: 0 additions & 1 deletion
@@ -53,7 +53,6 @@ Tensor& rsub_scalar_out(
     const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
-
          return val_b - val_alpha * val_a;
        },
        ctx,

kernels/portable/cpu/pattern/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/pattern:pattern",
             "//executorch/kernels/portable/cpu/pattern:bitwise_op",
             "//executorch/kernels/portable/cpu/pattern:comparison_op",
+            "//executorch/kernels/portable/cpu/pattern:logical_op"
         ],
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
     )

kernels/portable/test/op_mul_test.cpp

Lines changed: 5 additions & 3 deletions
@@ -49,12 +49,14 @@ TEST_F(OpMulOutKernelTest, UnhandledDtypeDies) {
   std::vector<exec_aten::qint8> b_data(a_data);
   std::vector<exec_aten::qint8> out_data(a_data);
 
+  std::vector<exec_aten::DimOrderType> dim_order = {0, 1};
+
   auto a_impl = torch::executor::TensorImpl(
-      ScalarType::QInt8, 2, sizes.data(), a_data.data());
+      ScalarType::QInt8, 2, sizes.data(), a_data.data(), dim_order.data());
   auto b_impl = torch::executor::TensorImpl(
-      ScalarType::QInt8, 2, sizes.data(), b_data.data());
+      ScalarType::QInt8, 2, sizes.data(), b_data.data(), dim_order.data());
   auto out_impl = torch::executor::TensorImpl(
-      ScalarType::QInt8, 2, sizes.data(), out_data.data());
+      ScalarType::QInt8, 2, sizes.data(), out_data.data(), dim_order.data());
 
   // Two input tensors.
   Tensor a(&a_impl);
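The test now passes an explicit dim_order of {0, 1} into each TensorImpl constructor. As a rough sketch of what that argument encodes (plain C++ with a hypothetical strides_from_dim_order() helper, not the ExecuTorch API): dim_order lists the tensor's dimensions from outermost to innermost in memory, so {0, 1} is the ordinary contiguous row-major layout for a 2-D tensor.

// Hypothetical helper for illustration: derive contiguous strides from a
// dim_order, assuming dim_order runs from outermost to innermost dimension.
#include <cstdio>
#include <vector>

std::vector<int> strides_from_dim_order(
    const std::vector<int>& sizes, const std::vector<int>& dim_order) {
  std::vector<int> strides(sizes.size());
  int running = 1;
  // The last entry of dim_order is the fastest-varying (innermost) dimension.
  for (int i = static_cast<int>(dim_order.size()) - 1; i >= 0; --i) {
    strides[dim_order[i]] = running;
    running *= sizes[dim_order[i]];
  }
  return strides;
}

int main() {
  const std::vector<int> sizes = {2, 2};
  const std::vector<int> dim_order = {0, 1}; // same order the test passes
  const auto strides = strides_from_dim_order(sizes, dim_order);
  std::printf("strides: {%d, %d}\n", strides[0], strides[1]); // {2, 1}: row-major
  return 0;
}

The layout itself is unchanged from the default; the diff only makes the argument explicit now that the constructor shown above takes a dim_order pointer.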

0 commit comments