Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions kernels/portable/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
namespace torch {
namespace executor {
namespace native {
namespace impl {

Tensor& add_out(
KernelRuntimeContext& ctx,
Expand Down Expand Up @@ -151,6 +152,47 @@ Tensor& add_scalar_out(
return out;
}

} // namespace impl

// Public entry point for the add.out kernel in the native namespace.
// Thin forwarder: delegates to the shared implementation in impl::add_out,
// which writes its result into `out` and returns it, so the same code can
// also be re-exported under native::utils for reuse by other operators.
Tensor& add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  return impl::add_out(ctx, a, b, alpha, out);
}

// Public entry point for the add.Scalar_out kernel (tensor + scalar other).
// Thin forwarder: delegates to impl::add_scalar_out, which writes its result
// into `out` and returns it.
Tensor& add_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  return impl::add_scalar_out(ctx, a, b, alpha, out);
}

namespace utils {

// Re-export of add_out under native::utils. Both utils:: overloads below
// forward to the identical impl:: implementations, giving other operator
// libraries a stable, separately-linkable entry point (the build exposes
// these via a dedicated "_util" target) without duplicating kernel logic.
Tensor& add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  return impl::add_out(ctx, a, b, alpha, out);
}

// Re-export of add_scalar_out under native::utils; see note on add_out above.
Tensor& add_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  return impl::add_scalar_out(ctx, a, b, alpha, out);
}

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch
35 changes: 35 additions & 0 deletions kernels/portable/cpu/op_add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Fix: `#pragma once` must precede any #include so the whole header —
// including its dependencies' processing — is guarded on re-inclusion.
#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace utils {

/// add.out kernel re-exported for reuse by other operators.
/// Writes the result into `out` and returns `out`.
/// NOTE(review): assumed to follow aten::add.out semantics (a + alpha * b);
/// the implementation lives in op_add.cpp — confirm there.
///
/// @param ctx Kernel runtime context (error reporting / allocation).
/// @param a First input tensor.
/// @param b Second input tensor.
/// @param alpha Scaling factor applied per aten::add semantics.
/// @param out Destination tensor; also the return value.
Tensor& add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out);

/// add.Scalar_out kernel re-exported for reuse by other operators.
/// Same contract as add_out, with a Scalar second operand.
Tensor& add_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out);

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch
22 changes: 22 additions & 0 deletions kernels/portable/cpu/op_stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
namespace torch {
namespace executor {
namespace native {
namespace impl {

using Tensor = executorch::aten::Tensor;

Expand Down Expand Up @@ -76,6 +77,27 @@ Tensor& stack_out(
return out;
}

} // namespace impl

// Public entry point for the stack.out kernel in the native namespace.
// Thin forwarder: delegates to the shared implementation in impl::stack_out,
// which writes its result into `out` and returns it.
Tensor& stack_out(
    KernelRuntimeContext& ctx,
    executorch::aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  return impl::stack_out(ctx, tensors, dim, out);
}

namespace utils {

// Re-export of stack_out under native::utils. Forwards to the identical
// impl::stack_out implementation, giving other operator libraries a stable
// entry point (exposed by the build as a separate "_util" target) without
// duplicating kernel logic.
Tensor& stack_out(
    KernelRuntimeContext& ctx,
    executorch::aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  return impl::stack_out(ctx, tensors, dim, out);
}

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch
27 changes: 27 additions & 0 deletions kernels/portable/cpu/op_stack.h
Original file line number	Diff line number	Diff line change
@@ -0,0 +1,27 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Fix: `#pragma once` must precede any #include so the whole header —
// including its dependencies' processing — is guarded on re-inclusion.
#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace utils {

/// stack.out kernel re-exported for reuse by other operators.
/// Writes the result into `out` and returns `out`.
/// NOTE(review): assumed to follow aten::stack.out semantics (stack `tensors`
/// along new dimension `dim`); the implementation lives in op_stack.cpp.
///
/// @param ctx Kernel runtime context (error reporting / allocation).
/// @param tensors Input tensors to stack.
/// @param dim Dimension index at which to insert the new axis.
/// @param out Destination tensor; also the return value.
Tensor& stack_out(
    KernelRuntimeContext& ctx,
    executorch::aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out);

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch
38 changes: 36 additions & 2 deletions shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def get_compiler_optimization_flags():
    # App size regressions require this to be backed out until I have a better solution
return []

def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = False, _aten_mode_deps = []):
def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = False, _aten_mode_deps = [], exposed_as_util = False):
"""Registers an implementation of an operator overload group.

An operator overload group is a set of operator overloads with a common
Expand Down Expand Up @@ -45,6 +45,8 @@ def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = Fals
from third-party optimization libraries.
_aten_mode_deps: List of deps to add to the cxx_library() when building
for ATen mode.
exposed_as_util: If True, this op has a utils namespace that should be exposed
as a separate library target for reuse by other operators.
"""

# Note that this doesn't actually define the target, but helps register
Expand All @@ -55,6 +57,7 @@ def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = Fals
"name": name,
"_allow_third_party_deps": _allow_third_party_deps,
"_aten_mode_deps": _aten_mode_deps,
"exposed_as_util": exposed_as_util,
}

def _enforce_deps(deps, name, allow_third_party_deps):
Expand Down Expand Up @@ -154,7 +157,7 @@ def define_op_library(name, deps, android_deps, aten_target, _allow_third_party_
link_whole = True,
)

def define_op_target(name, deps, android_deps, is_aten_op, is_et_op = True, _allow_third_party_deps = False, _aten_mode_deps = []):
def define_op_target(name, deps, android_deps, is_aten_op, is_et_op = True, _allow_third_party_deps = False, _aten_mode_deps = [], exposed_as_util = False):
"""Possibly defines cxx_library targets for the named operator group.

Args:
Expand All @@ -166,8 +169,37 @@ def define_op_target(name, deps, android_deps, is_aten_op, is_et_op = True, _all
_allow_third_party_deps: If True, the op is allowed to depend on
third-party deps outside of //executorch. Should only be used by
targets under //executorch/kernels/optimized.
exposed_as_util: If True, this op has a utils namespace that should be exposed
as a separate library target for reuse by other operators.
"""

# If this op has utils, create a separate utils library target
if exposed_as_util:
utils_name = name + "_util"
runtime.cxx_library(
name = utils_name,
srcs = ["{}.cpp".format(name)],
exported_headers = ["{}.h".format(name)],
visibility = [
"//executorch/kernels/portable/...",
"//executorch/kernels/quantized/...",
"//executorch/kernels/optimized/...",
"//executorch/kernels/test/...",
"@EXECUTORCH_CLIENTS",
],
fbandroid_platform_deps = android_deps,
compiler_flags = select({
"DEFAULT": ["-Wno-missing-prototypes"],
"ovr_config//os:windows": [],
}) + (
["-fvisibility=hidden"] if is_xplat() else []
) + get_compiler_optimization_flags(),
deps = [
"//executorch/runtime/kernel:kernel_includes",
] + deps,
force_static = True,
)

# If this is a custom op, define a target that builds it with at::Tensor
# so that it can be imported into a host PyTorch environment for authoring.
if not is_aten_op and True in get_aten_mode_options():
Expand Down Expand Up @@ -226,6 +258,7 @@ ATEN_OPS = (
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
":scalar_utils",
],
exposed_as_util = True,
),
op_target(
name = "op_addmm",
Expand Down Expand Up @@ -1194,6 +1227,7 @@ ATEN_OPS = (
deps = [
"//executorch/kernels/portable/cpu/util:copy_ops_util",
],
exposed_as_util = True,
),
op_target(
name = "op_sub",
Expand Down
Loading