17 changes: 2 additions & 15 deletions kernels/portable/cpu/op_add.cpp
@@ -22,14 +22,6 @@ Tensor& add_out(
const Tensor& b,
const Scalar& alpha,
Tensor& out) {
ET_KERNEL_CHECK(
ctx,
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(b) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);

// Common Dtype
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

@@ -64,6 +56,7 @@ Tensor& add_out(
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a + val_alpha * val_b;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
@@ -81,13 +74,6 @@ Tensor& add_scalar_out(
const Scalar& b,
const Scalar& alpha,
Tensor& out) {
ET_KERNEL_CHECK(
ctx,
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);

// Common Dtype
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

@@ -120,6 +106,7 @@ Tensor& add_scalar_out(
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
return val_a + val_alpha * val_b;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
18 changes: 2 additions & 16 deletions kernels/portable/cpu/op_clamp.cpp
@@ -73,13 +73,6 @@ Tensor& clamp_out(
const exec_aten::optional<Scalar>& min_opt,
const exec_aten::optional<Scalar>& max_opt,
Tensor& out) {
ET_KERNEL_CHECK(
ctx,
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);

bool has_min = min_opt.has_value();
bool has_max = max_opt.has_value();

@@ -154,6 +147,7 @@ Tensor& clamp_out(
}
return val_out;
},
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
out,
@@ -182,15 +176,6 @@ Tensor& clamp_tensor_out(
const Tensor& min = has_min ? min_opt.value() : in;
const Tensor& max = has_max ? max_opt.value() : in;

ET_KERNEL_CHECK(
ctx,
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
executorch::runtime::tensor_is_realhbbf16_type(min) &&
executorch::runtime::tensor_is_realhbbf16_type(max) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);

// Common Dtype
ScalarType common_type = in.scalar_type();
if (has_min) {
@@ -239,6 +224,7 @@ Tensor& clamp_tensor_out(
}
return val_out;
},
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
min,
11 changes: 1 addition & 10 deletions kernels/portable/cpu/op_where.cpp
@@ -19,16 +19,6 @@ Tensor& where_out(
const Tensor& a,
const Tensor& b,
Tensor& out) {
ET_KERNEL_CHECK(
ctx,
((cond.scalar_type() == ScalarType::Bool ||
cond.scalar_type() == ScalarType::Byte) &&
executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(b) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);

// Common Dtype
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

@@ -57,6 +47,7 @@ Tensor& where_out(
[](const CTYPE_COMPUTE val_a,
const CTYPE_COMPUTE val_b,
const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
50 changes: 50 additions & 0 deletions kernels/portable/cpu/util/elementwise_util.cpp
@@ -0,0 +1,50 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

namespace torch {
namespace executor {
namespace native {
namespace utils {
namespace internal {

bool check_tensor_dtype(
const Tensor t,
SupportedTensorDtypes dtypes,
const ScalarType compute_type) {
switch (dtypes) {
case SupportedTensorDtypes::REALHBBF16:
return executorch::runtime::tensor_is_realhbbf16_type(t);
case SupportedTensorDtypes::BOOL_OR_BYTE:
return (
executorch::runtime::tensor_is_type(t, ScalarType::Bool) ||
executorch::runtime::tensor_is_type(t, ScalarType::Byte));
case SupportedTensorDtypes::SAME_AS_COMPUTE:
return executorch::runtime::tensor_is_type(t, compute_type);
case SupportedTensorDtypes::SAME_AS_COMMON: {
if (compute_type == ScalarType::Float) {
return (
executorch::runtime::tensor_is_type(t, ScalarType::Float) ||
executorch::runtime::tensor_is_type(t, ScalarType::Half) ||
executorch::runtime::tensor_is_type(t, ScalarType::BFloat16));
} else {
return executorch::runtime::tensor_is_type(t, compute_type);
}
}
}
ET_CHECK(false);
return false;
}

} // namespace internal
} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch
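
For reference, a minimal sketch (not part of this diff) of what the new helper checks for a single tensor; the `ctx`/`out` names and the `Float` compute type are placeholders, and the `apply_*_elementwise_fn` templates in elementwise_util.h perform this check internally:

```cpp
// Hypothetical direct use inside a portable op (torch::executor::native scope).
// Fails the kernel with InvalidArgument if `out` is not a REALHBBF16 tensor.
ET_KERNEL_CHECK(
    ctx,
    utils::internal::check_tensor_dtype(
        out, utils::SupportedTensorDtypes::REALHBBF16, ScalarType::Float),
    InvalidArgument,
    out);
```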
41 changes: 38 additions & 3 deletions kernels/portable/cpu/util/elementwise_util.h
@@ -229,15 +229,29 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
return nullptr;
}

bool check_tensor_dtype(
const Tensor t,
SupportedTensorDtypes dtypes,
const ScalarType compute_type);

} // namespace internal

template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_unitensor_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

ET_KERNEL_CHECK(
ctx,
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
InvalidArgument, );

const auto load_a_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
const auto store_common_to_out =
@@ -263,12 +277,22 @@ inline void apply_unitensor_elementwise_fn(
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_bitensor_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& b,
SupportedTensorDtypes b_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

ET_KERNEL_CHECK(
ctx,
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
InvalidArgument, );

const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
@@ -312,9 +336,9 @@ inline void apply_bitensor_elementwise_fn(
}

/**
* Useful for tri-tensor elementwise operators. For each element of the inputs,
* perform a computation and write to the corresponding element of the output.
* Tensor broadcasting is applied wherever it is required.
* Useful for tri-tensor elementwise operators. For each element of the
* inputs, perform a computation and write to the corresponding element of the
* output. Tensor broadcasting is applied wherever it is required.
*
* In order to mitigate build time cost (straightforwardly |CTYPE_A| *
* |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
@@ -334,6 +358,7 @@ inline void apply_bitensor_elementwise_fn(
template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_tritensor_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& b,
Expand All @@ -342,6 +367,16 @@ inline void apply_tritensor_elementwise_fn(
SupportedTensorDtypes c_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

ET_KERNEL_CHECK(
ctx,
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
InvalidArgument, );

const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
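
As a usage sketch (assumed, not taken from this diff): with `ctx` now threaded through, a call site in the style of the op files above looks roughly like the following; `op_name`, the lambda, and the dtype choices are illustrative only.

```cpp
// Hypothetical unary call site; CTYPE_COMPUTE stands for the compute type
// selected by the surrounding ET_SWITCH macro in a real op.
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
    [](const CTYPE_COMPUTE val_in) { return val_in * val_in; },
    ctx, // lets the helper report dtype errors via ET_KERNEL_CHECK
    in,
    utils::SupportedTensorDtypes::REALHBBF16,
    out,
    utils::SupportedTensorDtypes::SAME_AS_COMMON);
```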
1 change: 1 addition & 0 deletions kernels/portable/cpu/util/targets.bzl
@@ -80,6 +80,7 @@ def define_common_targets():

runtime.cxx_library(
name = "elementwise_util",
srcs = ["elementwise_util.cpp"],
exported_headers = [
"elementwise_util.h",
],
10 changes: 10 additions & 0 deletions runtime/core/exec_aten/util/tensor_util.h
@@ -469,6 +469,16 @@ inline bool tensor_is_bool_type(exec_aten::Tensor t) {
return true;
}

inline bool tensor_is_type(exec_aten::Tensor t, exec_aten::ScalarType dtype) {
ET_LOG_MSG_AND_RETURN_IF_FALSE(
t.scalar_type() == dtype,
"Expected to find %s type, but tensor has type %s",
torch::executor::toString(dtype),
torch::executor::toString(t.scalar_type()));

return true;
}

inline bool tensor_is_integral_type(
exec_aten::Tensor t,
bool includeBool = false) {
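
A small hypothetical example (not part of this diff) of the new `tensor_is_type` helper, following the ET_KERNEL_CHECK pattern that op_where.cpp used before this change; `ctx`, `cond`, and `out` are placeholders:

```cpp
// Hypothetical kernel-side validation of a boolean condition tensor.
ET_KERNEL_CHECK(
    ctx,
    executorch::runtime::tensor_is_type(cond, ScalarType::Bool),
    InvalidArgument,
    out);
```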