
Commit bc5186c

Expose portable ops as utils (add/stack)
Differential Revision: D79654142
Pull Request resolved: #13200
1 parent fff2090 commit bc5186c

File tree

5 files changed: +162, -2 lines changed

kernels/portable/cpu/op_add.cpp
kernels/portable/cpu/op_add.h
kernels/portable/cpu/op_stack.cpp
kernels/portable/cpu/op_stack.h
shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

kernels/portable/cpu/op_add.cpp

Lines changed: 42 additions & 0 deletions

@@ -15,6 +15,7 @@
 namespace torch {
 namespace executor {
 namespace native {
+namespace impl {

 Tensor& add_out(
     KernelRuntimeContext& ctx,
@@ -151,6 +152,47 @@ Tensor& add_scalar_out(
   return out;
 }

+} // namespace impl
+
+Tensor& add_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  return impl::add_out(ctx, a, b, alpha, out);
+}
+
+Tensor& add_scalar_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  return impl::add_scalar_out(ctx, a, b, alpha, out);
+}
+
+namespace utils {
+
+Tensor& add_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  return impl::add_out(ctx, a, b, alpha, out);
+}
+
+Tensor& add_scalar_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  return impl::add_scalar_out(ctx, a, b, alpha, out);
+}
+
+} // namespace utils
 } // namespace native
 } // namespace executor
 } // namespace torch

kernels/portable/cpu/op_add.h

Lines changed: 35 additions & 0 deletions (new file)

@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+#pragma once
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace utils {
+
+Tensor& add_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out);
+
+Tensor& add_scalar_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out);
+
+} // namespace utils
+} // namespace native
+} // namespace executor
+} // namespace torch
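With the header in place, any target that depends on the generated op_add_util library can call the portable add directly. Below is a minimal caller sketch, not part of this commit: the example namespace and plain_add wrapper are invented for illustration, and the caller is assumed to already hold a valid KernelRuntimeContext and a correctly sized out tensor.

#include <executorch/kernels/portable/cpu/op_add.h>

using executorch::aten::Scalar;
using executorch::aten::Tensor;

namespace example { // hypothetical consumer, not part of this commit

// out = a + b, delegating to the portable implementation exposed above.
Tensor& plain_add(
    torch::executor::KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  return torch::executor::native::utils::add_out(
      ctx, a, b, /*alpha=*/Scalar(1), out);
}

} // namespace example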

kernels/portable/cpu/op_stack.cpp

Lines changed: 22 additions & 0 deletions

@@ -14,6 +14,7 @@
 namespace torch {
 namespace executor {
 namespace native {
+namespace impl {

 using Tensor = executorch::aten::Tensor;

@@ -76,6 +77,27 @@ Tensor& stack_out(
   return out;
 }

+} // namespace impl
+
+Tensor& stack_out(
+    KernelRuntimeContext& ctx,
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out) {
+  return impl::stack_out(ctx, tensors, dim, out);
+}
+
+namespace utils {
+
+Tensor& stack_out(
+    KernelRuntimeContext& ctx,
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out) {
+  return impl::stack_out(ctx, tensors, dim, out);
+}
+
+} // namespace utils
 } // namespace native
 } // namespace executor
 } // namespace torch

kernels/portable/cpu/op_stack.h

Lines changed: 27 additions & 0 deletions (new file)

@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+#pragma once
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace utils {
+
+Tensor& stack_out(
+    KernelRuntimeContext& ctx,
+    executorch::aten::ArrayRef<Tensor> tensors,
+    int64_t dim,
+    Tensor& out);
+
+} // namespace utils
+} // namespace native
+} // namespace executor
+} // namespace torch
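The stack util can be exercised the same way. Another hedged sketch: stack_pair is hypothetical, and it assumes a and b share a shape and that out was pre-sized to hold both tensors along a new leading dimension.

#include <executorch/kernels/portable/cpu/op_stack.h>

using executorch::aten::ArrayRef;
using executorch::aten::Tensor;

namespace example { // hypothetical consumer, not part of this commit

// out = stack([a, b], dim=0), delegating to the portable implementation.
Tensor& stack_pair(
    torch::executor::KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  Tensor inputs[2] = {a, b};
  return torch::executor::native::utils::stack_out(
      ctx, ArrayRef<Tensor>(inputs, 2), /*dim=*/0, out);
}

} // namespace example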

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 36 additions & 2 deletions

@@ -5,7 +5,7 @@ def get_compiler_optimization_flags():
     # App size regressons requires this to be baktraced until I have a better solution
     return []

-def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = False, _aten_mode_deps = []):
+def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = False, _aten_mode_deps = [], exposed_as_util = False):
     """Registers an implementation of an operator overload group.

     An operator overload group is a set of operator overloads with a common
@@ -45,6 +45,8 @@ def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = Fals
             from third-party optimization libraries.
         _aten_mode_deps: List of deps to add to the cxx_library() when building
             for ATen mode.
+        exposed_as_util: If True, this op has a utils namespace that should be exposed
+            as a separate library target for reuse by other operators.
     """

     # Note that this doesn't actually define the target, but helps register
@@ -55,6 +57,7 @@ def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = Fals
         "name": name,
         "_allow_third_party_deps": _allow_third_party_deps,
         "_aten_mode_deps": _aten_mode_deps,
+        "exposed_as_util": exposed_as_util,
     }

 def _enforce_deps(deps, name, allow_third_party_deps):
@@ -154,7 +157,7 @@ def define_op_library(name, deps, android_deps, aten_target, _allow_third_party_
         link_whole = True,
     )

-def define_op_target(name, deps, android_deps, is_aten_op, is_et_op = True, _allow_third_party_deps = False, _aten_mode_deps = []):
+def define_op_target(name, deps, android_deps, is_aten_op, is_et_op = True, _allow_third_party_deps = False, _aten_mode_deps = [], exposed_as_util = False):
     """Possibly defines cxx_library targets for the named operator group.

     Args:
@@ -166,8 +169,37 @@ def define_op_target(name, deps, android_deps, is_aten_op, is_et_op = True, _all
         _allow_third_party_deps: If True, the op is allowed to depend on
             third-party deps outside of //executorch. Should only be used by
             targets under //executorch/kernels/optimized.
+        exposed_as_util: If True, this op has a utils namespace that should be exposed
+            as a separate library target for reuse by other operators.
     """

+    # If this op has utils, create a separate utils library target
+    if exposed_as_util:
+        utils_name = name + "_util"
+        runtime.cxx_library(
+            name = utils_name,
+            srcs = ["{}.cpp".format(name)],
+            exported_headers = ["{}.h".format(name)],
+            visibility = [
+                "//executorch/kernels/portable/...",
+                "//executorch/kernels/quantized/...",
+                "//executorch/kernels/optimized/...",
+                "//executorch/kernels/test/...",
+                "@EXECUTORCH_CLIENTS",
+            ],
+            fbandroid_platform_deps = android_deps,
+            compiler_flags = select({
+                "DEFAULT": ["-Wno-missing-prototypes"],
+                "ovr_config//os:windows": [],
+            }) + (
+                ["-fvisibility=hidden"] if is_xplat() else []
+            ) + get_compiler_optimization_flags(),
+            deps = [
+                "//executorch/runtime/kernel:kernel_includes",
+            ] + deps,
+            force_static = True,
+        )
+
     # If this is a custom op, define a target that builds it with at::Tensor
     # so that it can be imported into a host PyTorch environment for authoring.
     if not is_aten_op and True in get_aten_mode_options():
@@ -226,6 +258,7 @@ ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:kernel_ops_util",
             ":scalar_utils",
         ],
+        exposed_as_util = True,
     ),
     op_target(
         name = "op_addmm",
@@ -1194,6 +1227,7 @@ ATEN_OPS = (
         deps = [
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
         ],
+        exposed_as_util = True,
     ),
     op_target(
         name = "op_sub",
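Taken together, flagging op_add and op_stack with exposed_as_util = True makes define_op_target emit op_add_util and op_stack_util cxx_library targets, and the docstring states the intent: reuse by other operators. One plausible shape of that reuse is a specialized kernel falling back to the portable path. The sketch below is hypothetical (the kernel name and fast-path check are invented); only the utils::add_out call reflects this commit.

#include <executorch/kernels/portable/cpu/op_add.h>

namespace torch {
namespace executor {
namespace native {

// Hypothetical specialized add: attempt a fast path, otherwise delegate
// to the portable implementation now exposed under utils::.
Tensor& my_fast_add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  const bool fast_path_applies = false;  // placeholder for a real dispatch check
  if (fast_path_applies) {
    // ... a vectorized implementation would go here ...
  }
  return utils::add_out(ctx, a, b, alpha, out);  // portable fallback
}

} // namespace native
} // namespace executor
} // namespace torch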
