Skip to content

Commit 04ee3b9

Browse files
committed
[Executorch][portable] Split op_mul in impl plus op
Pull Request resolved: pytorch/executorch#6732 This is to facilitate reuse of portable for optimized mul's fallback ghstack-source-id: 252584472 @exported-using-ghexport Differential Revision: [D65628859](https://our.internmc.facebook.com/intern/diff/D65628859/)
1 parent 0c9b35c commit 04ee3b9

File tree

5 files changed

+169
-87
lines changed

5 files changed

+169
-87
lines changed

kernels/portable/cpu/op_mul.cpp

Lines changed: 3 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10-
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11-
#include <executorch/runtime/kernel/kernel_includes.h>
12-
#include <executorch/runtime/platform/assert.h>
9+
#include <executorch/kernels/portable/cpu/op_mul_impl.h>
1310

1411
namespace torch {
1512
namespace executor {
@@ -20,91 +17,15 @@ Tensor& mul_out(
2017
const Tensor& a,
2118
const Tensor& b,
2219
Tensor& out) {
23-
// Common Dtype
24-
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
25-
26-
// Check Common Dtype
27-
ET_KERNEL_CHECK(
28-
ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
29-
30-
// Check Dim Order
31-
ET_KERNEL_CHECK(
32-
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
33-
34-
// Resize
35-
ET_KERNEL_CHECK(
36-
ctx,
37-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
38-
InvalidArgument,
39-
out);
40-
41-
// Compute Dtype
42-
ScalarType compute_type = utils::get_compute_type(common_type);
43-
44-
// @lint-ignore CLANGTIDY facebook-hte-CArray
45-
static constexpr const char op_name[] = "mul.out";
46-
47-
ET_KERNEL_CHECK(
48-
ctx,
49-
(executorch::runtime::isRealType(compute_type) ||
50-
compute_type == ScalarType::Bool),
51-
InvalidArgument,
52-
out);
53-
54-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
55-
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
57-
return val_a * val_b;
58-
},
59-
ctx,
60-
a,
61-
utils::SupportedTensorDtypes::REALHBBF16,
62-
b,
63-
utils::SupportedTensorDtypes::REALHBBF16,
64-
out,
65-
utils::SupportedTensorDtypes::REALHBBF16);
66-
});
67-
68-
return out;
20+
return mul_out_impl(ctx, a, b, out);
6921
}
7022

7123
Tensor& mul_scalar_out(
7224
KernelRuntimeContext& ctx,
7325
const Tensor& a,
7426
const Scalar& b,
7527
Tensor& out) {
76-
// Common Dtype
77-
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
78-
79-
// Check Common Dtype
80-
ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
81-
82-
// Check Dim Order
83-
ET_KERNEL_CHECK(
84-
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
85-
86-
// Resize
87-
ET_KERNEL_CHECK(
88-
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
89-
90-
// Compute Dtype
91-
ScalarType compute_type = utils::get_compute_type(common_type);
92-
93-
// @lint-ignore CLANGTIDY facebook-hte-CArray
94-
static constexpr const char op_name[] = "mul.Scalar_out";
95-
96-
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
97-
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
98-
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
99-
[val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; },
100-
ctx,
101-
a,
102-
utils::SupportedTensorDtypes::REALHBBF16,
103-
out,
104-
utils::SupportedTensorDtypes::SAME_AS_COMMON);
105-
});
106-
107-
return out;
28+
return mul_scalar_out_impl(ctx, a, b, out);
10829
}
10930

11031
} // namespace native
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
10+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/runtime/platform/assert.h>
13+
14+
#include <executorch/kernels/portable/cpu/op_mul_impl.h>
15+
16+
namespace torch {
17+
namespace executor {
18+
namespace native {
19+
20+
Tensor& mul_out_impl(
21+
KernelRuntimeContext& ctx,
22+
const Tensor& a,
23+
const Tensor& b,
24+
Tensor& out) {
25+
// Common Dtype
26+
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27+
28+
// Check Common Dtype
29+
ET_KERNEL_CHECK(
30+
ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
31+
32+
// Check Dim Order
33+
ET_KERNEL_CHECK(
34+
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
35+
36+
// Resize
37+
ET_KERNEL_CHECK(
38+
ctx,
39+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
40+
InvalidArgument,
41+
out);
42+
43+
// Compute Dtype
44+
ScalarType compute_type = utils::get_compute_type(common_type);
45+
46+
// @lint-ignore CLANGTIDY facebook-hte-CArray
47+
static constexpr const char op_name[] = "mul.out";
48+
49+
ET_KERNEL_CHECK(
50+
ctx,
51+
(executorch::runtime::isRealType(compute_type) ||
52+
compute_type == ScalarType::Bool),
53+
InvalidArgument,
54+
out);
55+
56+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
57+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
58+
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59+
return val_a * val_b;
60+
},
61+
ctx,
62+
a,
63+
utils::SupportedTensorDtypes::REALHBBF16,
64+
b,
65+
utils::SupportedTensorDtypes::REALHBBF16,
66+
out,
67+
utils::SupportedTensorDtypes::REALHBBF16);
68+
});
69+
70+
return out;
71+
}
72+
73+
Tensor& mul_scalar_out_impl(
74+
KernelRuntimeContext& ctx,
75+
const Tensor& a,
76+
const Scalar& b,
77+
Tensor& out) {
78+
// Common Dtype
79+
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
80+
81+
// Check Common Dtype
82+
ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
83+
84+
// Check Dim Order
85+
ET_KERNEL_CHECK(
86+
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
87+
88+
// Resize
89+
ET_KERNEL_CHECK(
90+
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
91+
92+
// Compute Dtype
93+
ScalarType compute_type = utils::get_compute_type(common_type);
94+
95+
// @lint-ignore CLANGTIDY facebook-hte-CArray
96+
static constexpr const char op_name[] = "mul.Scalar_out";
97+
98+
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
99+
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
100+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
101+
[val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; },
102+
ctx,
103+
a,
104+
utils::SupportedTensorDtypes::REALHBBF16,
105+
out,
106+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
107+
});
108+
109+
return out;
110+
}
111+
112+
} // namespace native
113+
} // namespace executor
114+
} // namespace torch

kernels/portable/cpu/op_mul_impl.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/kernel/kernel_includes.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace native {
14+
15+
Tensor& mul_out_impl(
16+
KernelRuntimeContext& ctx,
17+
const Tensor& a,
18+
const Tensor& b,
19+
Tensor& out);
20+
21+
Tensor& mul_scalar_out_impl(
22+
KernelRuntimeContext& ctx,
23+
const Tensor& a,
24+
const Scalar& b,
25+
Tensor& out);
26+
27+
} // namespace native
28+
} // namespace executor
29+
} // namespace torch

kernels/portable/cpu/targets.bzl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,26 @@ def define_common_targets():
111111
],
112112
)
113113

114+
runtime.cxx_library(
115+
name = "op_mul_impl",
116+
srcs = ["op_mul_impl.cpp"],
117+
exported_headers = ["op_mul_impl.h"],
118+
visibility = [
119+
"//executorch/kernels/portable/cpu/...",
120+
"//executorch/kernels/optimized/cpu/...",
121+
"//executorch/kernels/portable/test/...",
122+
"@EXECUTORCH_CLIENTS",
123+
],
124+
exported_deps = [
125+
"//executorch/kernels/portable/cpu/util:broadcast_util",
126+
"//executorch/kernels/portable/cpu/util:dtype_util",
127+
"//executorch/kernels/portable/cpu/util:elementwise_util",
128+
"//executorch/kernels/portable/cpu/util:math_util",
129+
"//executorch/kernels/portable/cpu:scalar_utils",
130+
"//executorch/runtime/kernel:kernel_includes",
131+
],
132+
)
133+
114134
# The following will not participate in dtype selective build because
115135
# they are refactored such to be used in optimized op implementations as well
116136
# and we have not enabled selective build for optimized ops.
@@ -123,6 +143,7 @@ def define_common_targets():
123143
name = "all_impl_deps",
124144
deps = [
125145
"//executorch/kernels/portable/cpu:op_div_impl",
146+
"//executorch/kernels/portable/cpu:op_mul_impl",
126147
],
127148
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
128149
)

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -842,10 +842,7 @@ ATEN_OPS = (
842842
op_target(
843843
name = "op_mul",
844844
deps = [
845-
"//executorch/kernels/portable/cpu/util:broadcast_util",
846-
"//executorch/kernels/portable/cpu/util:dtype_util",
847-
"//executorch/kernels/portable/cpu/util:elementwise_util",
848-
":scalar_utils",
845+
"//executorch/kernels/portable/cpu:op_mul_impl",
849846
],
850847
),
851848
op_target(
@@ -1271,4 +1268,4 @@ def portable_source_list():
12711268

12721269
def portable_header_list():
12731270
"""All the header file names from //executorch/kernels/portable/cpu/"""
1274-
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h"]
1271+
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h"]

0 commit comments

Comments
 (0)