Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions kernels/portable/cpu/op_logical_and.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/pattern/pattern.h>
#include <executorch/kernels/portable/cpu/pattern/logical_op.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cmath>

Expand All @@ -26,8 +26,8 @@ Tensor& logical_and_out(
const Tensor& a,
const Tensor& b,
Tensor& out) {
return internal::binary_ufunc_realb_realb_to_realb_logical(
logical_and, ctx, a, b, out);
static constexpr const char op_name[] = "logical_and.out";
return internal::logical_tensor_out<op_name>(logical_and, ctx, a, b, out);
}

} // namespace native
Expand Down
6 changes: 3 additions & 3 deletions kernels/portable/cpu/op_logical_or.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/pattern/pattern.h>
#include <executorch/kernels/portable/cpu/pattern/logical_op.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cmath>

Expand All @@ -26,8 +26,8 @@ Tensor& logical_or_out(
const Tensor& a,
const Tensor& b,
Tensor& out) {
return internal::binary_ufunc_realb_realb_to_realb_logical(
logical_or, ctx, a, b, out);
static constexpr const char op_name[] = "logical_or.out";
return internal::logical_tensor_out<op_name>(logical_or, ctx, a, b, out);
}

} // namespace native
Expand Down
6 changes: 3 additions & 3 deletions kernels/portable/cpu/op_logical_xor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/pattern/pattern.h>
#include <executorch/kernels/portable/cpu/pattern/logical_op.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cmath>

Expand All @@ -26,8 +26,8 @@ Tensor& logical_xor_out(
const Tensor& a,
const Tensor& b,
Tensor& out) {
return internal::binary_ufunc_realb_realb_to_realb_logical(
logical_xor, ctx, a, b, out);
static constexpr const char op_name[] = "logical_xor.out";
return internal::logical_tensor_out<op_name>(logical_xor, ctx, a, b, out);
}

} // namespace native
Expand Down

This file was deleted.

53 changes: 53 additions & 0 deletions kernels/portable/cpu/pattern/logical_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace internal {

/**
* Implements an op pattern for ops that take two broadcastable input tensors
* and performs an element-wise binary logical operation `fn`.
*/
template <const char* op_name>
Tensor& logical_tensor_out(
bool (*fn)(bool, bool),
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out) {
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

utils::apply_bitensor_elementwise_fn<bool, op_name>(
fn,
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);

return out;
}

} // namespace internal
} // namespace native
} // namespace executor
} // namespace torch
13 changes: 0 additions & 13 deletions kernels/portable/cpu/pattern/pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,6 @@ Tensor& unary_ufunc_realhbbf16_to_floathbf16(
const Tensor& in,
Tensor& out);

/**
* Implements an op pattern for ops that take two broadcastable input tensors
* of any realb dtype, no additional arguments, performs an element-wise binary
* logical operation, and outputs a realb tensor. The function fn specifies the
* binary logical operation which is applied to the input tensors element-wise.
*/
Tensor& binary_ufunc_realb_realb_to_realb_logical(
bool (*fn)(bool, bool),
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out);

} // namespace internal
} // namespace native
} // namespace executor
Expand Down
10 changes: 9 additions & 1 deletion kernels/portable/cpu/pattern/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,21 @@ def define_common_targets():
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
)

runtime.cxx_library(
name = "logical_op",
exported_headers = [
"logical_op.h",
],
compiler_flags = [],
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
)

runtime.cxx_library(
name = "pattern",
srcs = [
"unary_ufunc_realhb_to_bool.cpp",
"unary_ufunc_realhbbf16_to_floathbf16.cpp",
"unary_ufunc_realh.cpp",
"binary_ufunc_realb_realb_to_realb_logical.cpp",
],
exported_headers = [
"pattern.h",
Expand Down
15 changes: 12 additions & 3 deletions shim/xplat/executorch/kernels/portable/op_registration_util.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,10 @@ ATEN_OPS = (
op_target(
name = "op_logical_and",
deps = [
"//executorch/kernels/portable/cpu/pattern:pattern",
":scalar_utils",
"//executorch/kernels/portable/cpu/pattern:logical_op",
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/kernels/portable/cpu/util:elementwise_util",
],
),
op_target(
Expand All @@ -712,13 +715,19 @@ ATEN_OPS = (
op_target(
name = "op_logical_or",
deps = [
"//executorch/kernels/portable/cpu/pattern:pattern",
":scalar_utils",
"//executorch/kernels/portable/cpu/pattern:logical_op",
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/kernels/portable/cpu/util:elementwise_util",
],
),
op_target(
name = "op_logical_xor",
deps = [
"//executorch/kernels/portable/cpu/pattern:pattern",
":scalar_utils",
"//executorch/kernels/portable/cpu/pattern:logical_op",
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/kernels/portable/cpu/util:elementwise_util",
],
),
op_target(
Expand Down
Loading