22 changes: 9 additions & 13 deletions kernels/portable/cpu/op__to_dim_order_copy.cpp
@@ -54,19 +54,15 @@ Tensor& _to_dim_order_copy_out(
     return out;
   }
 
-  ET_SWITCH_REALHBBF16_TYPES(
-      self.scalar_type(),
-      ctx,
-      "dim_order_ops::_to_dim_order_copy.out",
-      CTYPE_IN,
-      [&] {
-        ET_SWITCH_REALHBBF16_TYPES(
-            out.scalar_type(),
-            ctx,
-            "dim_order_ops::_to_dim_order_copy.out",
-            CTYPE_OUT,
-            [&] { _to_dim_order_copy_impl<CTYPE_IN, CTYPE_OUT>(self, out); });
-      });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] =
+      "dim_order_ops::_to_dim_order_copy.out";
+
+  ET_SWITCH_REALHBBF16_TYPES(self.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
+    ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
+      _to_dim_order_copy_impl<CTYPE_IN, CTYPE_OUT>(self, out);
+    });
+  });
 
   return out;
 }
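Note on the pattern repeated throughout this PR: `constexpr auto name = "..."` deduces to a pointer, while `static constexpr const char op_name[] = "..."` is a named array with static storage duration, which (since C++17) can also be bound to a `const char*` non-type template parameter. The `@lint-ignore CLANGTIDY facebook-hte-CArray` suppresses the lint rule that would otherwise flag the C-style array. A minimal standalone sketch of the distinction in plain C++17 (not the ExecuTorch macros, and only one plausible motivation, not stated in the diff):

```cpp
#include <cstdio>

// A const char* non-type template parameter must point at a named object
// with static storage duration; a bare string literal cannot be passed.
template <const char* Name>
void log_op() {
  std::printf("running %s\n", Name);
}

// Named array with static storage duration: a valid template argument.
static constexpr const char op_name[] = "dim_order_ops::_to_dim_order_copy.out";

int main() {
  log_op<op_name>(); // OK: decays to a pointer into a named static array
  // log_op<"dim_order_ops::...">(); // ill-formed before C++20: a string
  //                                 // literal has no name to refer to
  return 0;
}
```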
9 changes: 6 additions & 3 deletions kernels/portable/cpu/op_abs.cpp
@@ -37,13 +37,16 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "abs.out";
+
   if (in_is_complex) {
     // NOTE: Elected not to add COMPLEXH to dtype_util.h for now
     // because I am not planning wide rollout of complex support; if
     // we do add SupportedTensorDtypes::COMPLEXH support, then we
     // should use it here.
-    ET_SWITCH_COMPLEXH_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE_IN, [&] {
-      ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "abs.out", CTYPE_OUT, [&] {
+    ET_SWITCH_COMPLEXH_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
+      ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
         apply_unary_map_fn<CTYPE_IN, CTYPE_OUT>(
             [](const CTYPE_IN val_in) -> CTYPE_OUT {
               return sqrt(
@@ -55,7 +58,7 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
       });
     });
   } else {
-    ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
+    ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
       apply_unary_map_fn(
           [](const CTYPE val_in) {
             if (val_in < 0) {
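The complex branch maps a complex input to a real output, which is why the inner switch above is over FLOATH rather than COMPLEXH. A self-contained sketch of the shape of that branch (`apply_unary_map` here is a stand-in for `apply_unary_map_fn`, not the real helper):

```cpp
#include <cmath>
#include <complex>
#include <cstddef>

// Map a unary function over a flat input buffer into an output buffer of a
// possibly different element type.
template <typename CTYPE_IN, typename CTYPE_OUT, typename Fn>
void apply_unary_map(Fn fn, const CTYPE_IN* in, CTYPE_OUT* out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = fn(in[i]);
  }
}

int main() {
  // |a+bi| = sqrt(a^2 + b^2), matching the sqrt(...) in the lambda above.
  const std::complex<float> in[2] = {{3.0f, 4.0f}, {0.0f, -1.0f}};
  float out[2];
  apply_unary_map(
      [](const std::complex<float> v) -> float {
        return std::sqrt(v.real() * v.real() + v.imag() * v.imag());
      },
      in, out, 2);
  // out is now {5.0f, 1.0f}
  return 0;
}
```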
6 changes: 5 additions & 1 deletion kernels/portable/cpu/op_amax.cpp
@@ -44,7 +44,11 @@ Tensor& amax_out(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ReduceOverDimListPlan plan(in, dim_list);
-  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "amax.out";
+
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     const bool success = parallel_for_each_reduce_over_dim_list_output_index(
         in, dim_list, out, [&](const auto begin, const auto end) {
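For reference, a sketch of what the amax reduction computes in the simplest case, a row-major 2-D input with `dim_list = {1}`: one running max per row. The real kernel generalizes this over arbitrary dim lists and parallelizes per output index, and amin (next file) is identical with `std::min`:

```cpp
#include <algorithm>
#include <vector>

std::vector<float> amax_rows(const std::vector<float>& in, int rows, int cols) {
  std::vector<float> out(static_cast<size_t>(rows));
  for (int r = 0; r < rows; ++r) {
    // Seed the accumulator with the first element of the reduced span.
    float acc = in[static_cast<size_t>(r) * cols];
    for (int c = 1; c < cols; ++c) {
      acc = std::max(acc, in[static_cast<size_t>(r) * cols + c]);
    }
    out[static_cast<size_t>(r)] = acc;
  }
  return out;
}
```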
6 changes: 5 additions & 1 deletion kernels/portable/cpu/op_amin.cpp
@@ -43,7 +43,11 @@ Tensor& amin_out(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ReduceOverDimListPlan plan(in, dim_list);
-  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "amin.out";
+
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     const bool success = parallel_for_each_reduce_over_dim_list_output_index(
         in, dim_list, out, [&](const auto begin, const auto end) {
24 changes: 15 additions & 9 deletions kernels/portable/cpu/op_any.cpp
@@ -30,10 +30,12 @@ Tensor& any_all_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
 
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
-  constexpr auto name = "any.all_out";
 
-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
-    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "any.all_out";
+
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE_IN, [&] {
+    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, op_name, CTYPE_OUT, [&] {
       const auto data_in = in.const_data_ptr<CTYPE_IN>();
       auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
       data_out[0] = static_cast<CTYPE_OUT>(false);
@@ -79,15 +81,17 @@ Tensor& any_dims_out(
 
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
-  constexpr auto name = "any.dims_out";
 
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "any.dims_out";
+
   const bool in_not_empty = in.numel() > 0;
   std::optional<MapReduceOverDimListPlan> plan;
   if ((!dim_list.has_value() || !dim_list.value().empty()) && in_not_empty) {
     plan.emplace(in, dim_list);
   }
-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
-    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE_IN, [&] {
+    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, op_name, CTYPE_OUT, [&] {
       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
       if (dim_list.has_value() && dim_list.value().empty()) {
         const CTYPE_IN* in_data = in.const_data_ptr<CTYPE_IN>();
@@ -144,10 +148,12 @@ Tensor& any_out(
 
   ScalarType in_type = in.scalar_type();
   ScalarType out_type = out.scalar_type();
-  constexpr auto name = "any.out";
 
-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
-    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "any.out";
+
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE_IN, [&] {
+    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, op_name, CTYPE_OUT, [&] {
       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
       const bool success = parallel_for_each_reduce_over_dim_output_index(
           in, dim, out, [&](const auto begin, const auto end) {
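In all three `any.*` variants the input may be any REALHBBF16 type while the output is Bool or Byte, hence the `ET_SWITCH_TWO_TYPES` inner switch. A sketch of the per-output-element reduction they share (a stand-in, not the kernel's actual loop):

```cpp
#include <cstddef>

// Reduce a span to "does it contain any nonzero element?", widening through
// a plain bool accumulator and casting to the Bool/Byte output type at the end.
template <typename CTYPE_IN, typename CTYPE_OUT>
CTYPE_OUT any_reduce(const CTYPE_IN* in, size_t n) {
  bool acc = false;
  for (size_t i = 0; i < n && !acc; ++i) {
    acc = static_cast<bool>(in[i]);  // nonzero counts as true
  }
  return static_cast<CTYPE_OUT>(acc);
}
```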
5 changes: 4 additions & 1 deletion kernels/portable/cpu/op_argmax.cpp
@@ -44,7 +44,10 @@ Tensor& argmax_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
-  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmax.out", CTYPE, [&] {
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "argmax.out";
+
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
     const bool success = parallel_for_each_reduce_over_dim_output_index(
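Unlike amax, argmax reduces to an index rather than a value, written as `long` to match `out.mutable_data_ptr<long>()` above. A 1-D reference sketch (tie handling, first max wins via strict `>`, is this sketch's choice, and argmin in the next file is the mirror image with `<`):

```cpp
#include <cstddef>

long argmax_1d(const float* in, size_t n) {
  long best = 0;
  for (size_t i = 1; i < n; ++i) {
    if (in[i] > in[static_cast<size_t>(best)]) {
      best = static_cast<long>(i);
    }
  }
  return best;  // index of the first occurrence of the maximum
}
```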
5 changes: 4 additions & 1 deletion kernels/portable/cpu/op_argmin.cpp
@@ -44,7 +44,10 @@ Tensor& argmin_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
-  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "argmin.out";
+
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
     const bool success = parallel_for_each_reduce_over_dim_output_index(
97 changes: 50 additions & 47 deletions kernels/portable/cpu/op_avg_pool2d.cpp
@@ -67,53 +67,56 @@ Tensor& avg_pool2d_out(
       out);
 
   ScalarType in_type = in.scalar_type();
-  ET_SWITCH_FLOATHBF16_TYPES_AND(
-      Long, in_type, ctx, "avg_pool2d.out", CTYPE, [&]() {
-        if (divisor_override.has_value()) {
-          int64_t divisor = divisor_override.value();
-          // If divisor_override is specified, then we don't need to use `count`
-          // in the calculation. Simply sum x / divisor to get the output.
-          apply_kernel_2d_reduce_then_map_fn<CTYPE>(
-              [](const CTYPE in_val,
-                 int64_t in_idx,
-                 CTYPE accum,
-                 int64_t accum_idx) {
-                // Average pooling does not track indexes, so return 0 for
-                // accum_idx
-                return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
-              },
-              [divisor](const int64_t count, const CTYPE accum) {
-                return accum / static_cast<CTYPE>(divisor);
-              },
-              count_include_pad,
-              in,
-              kernel_size,
-              stride,
-              padding,
-              {},
-              out);
-        } else {
-          apply_kernel_2d_reduce_then_map_fn<CTYPE>(
-              [](const CTYPE in_val,
-                 int64_t in_idx,
-                 CTYPE accum,
-                 int64_t accum_idx) {
-                // Average pooling does not track indexes, so return 0 for
-                // accum_idx
-                return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
-              },
-              [](const int64_t count, const CTYPE accum) {
-                return accum / static_cast<CTYPE>(count);
-              },
-              count_include_pad,
-              in,
-              kernel_size,
-              stride,
-              padding,
-              {},
-              out);
-        }
-      });
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "avg_pool2d.out";
+
+  ET_SWITCH_FLOATHBF16_TYPES_AND(Long, in_type, ctx, op_name, CTYPE, [&]() {
+    if (divisor_override.has_value()) {
+      int64_t divisor = divisor_override.value();
+      // If divisor_override is specified, then we don't need to use `count`
+      // in the calculation. Simply sum x / divisor to get the output.
+      apply_kernel_2d_reduce_then_map_fn<CTYPE>(
+          [](const CTYPE in_val,
+             int64_t in_idx,
+             CTYPE accum,
+             int64_t accum_idx) {
+            // Average pooling does not track indexes, so return 0 for
+            // accum_idx
+            return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
+          },
+          [divisor](const int64_t count, const CTYPE accum) {
+            return accum / static_cast<CTYPE>(divisor);
+          },
+          count_include_pad,
+          in,
+          kernel_size,
+          stride,
+          padding,
+          {},
+          out);
+    } else {
+      apply_kernel_2d_reduce_then_map_fn<CTYPE>(
+          [](const CTYPE in_val,
+             int64_t in_idx,
+             CTYPE accum,
+             int64_t accum_idx) {
+            // Average pooling does not track indexes, so return 0 for
+            // accum_idx
+            return std::tuple<CTYPE, int64_t>(in_val + accum, 0);
+          },
+          [](const int64_t count, const CTYPE accum) {
+            return accum / static_cast<CTYPE>(count);
+          },
+          count_include_pad,
+          in,
+          kernel_size,
+          stride,
+          padding,
+          {},
+          out);
+    }
+  });
 
   return out;
 }
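The only difference between the two branches above is the denominator of the final map step: the fixed `divisor_override` versus the per-window element `count`. A toy single-window illustration (the real kernel also lets `count_include_pad` affect the count at the borders):

```cpp
#include <cstdio>

int main() {
  const float window[4] = {1.0f, 2.0f, 3.0f, 4.0f};  // one 2x2 pooling window
  float sum = 0.0f;
  for (float v : window) {
    sum += v;
  }
  const long count = 4;    // elements actually summed in the window
  const long divisor = 8;  // hypothetical divisor_override value

  std::printf("default:  %.2f\n", sum / static_cast<float>(count));    // 2.50
  std::printf("override: %.2f\n", sum / static_cast<float>(divisor));  // 1.25
  return 0;
}
```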
4 changes: 3 additions & 1 deletion kernels/portable/cpu/op_bitwise_not.cpp
@@ -37,14 +37,16 @@ bitwise_not_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "bitwise_not.out";
   if (in.scalar_type() == executorch::aten::ScalarType::Bool) {
     apply_unary_map_fn(
         [](const bool val_in) { return !val_in; },
         in.const_data_ptr<bool>(),
         out.mutable_data_ptr<bool>(),
         in.numel());
   } else if (isIntegralType(in.scalar_type(), /*includeBool=*/false)) {
-    ET_SWITCH_INT_TYPES(in.scalar_type(), ctx, "bitwise_not.out", CTYPE, [&] {
+    ET_SWITCH_INT_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
       apply_unary_map_fn(
           [](const CTYPE val_in) { return ~val_in; },
           in.const_data_ptr<CTYPE>(),
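The dtype split above follows ATen semantics: bitwise_not is logical NOT for bool but a two's-complement bitwise complement for integer types, which is why the bool path uses `!val_in` and the INT path uses `~val_in`. A tiny check of the difference:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  bool b = true;
  assert(!b == false);  // bool path: logical negation

  int8_t i = 5;
  assert(static_cast<int8_t>(~i) == -6);  // integral path: ~x == -x - 1
  return 0;
}
```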
7 changes: 4 additions & 3 deletions kernels/portable/cpu/op_bmm.cpp
@@ -36,16 +36,17 @@ Tensor& bmm_out(
       InvalidArgument,
       out);
 
-  constexpr auto name = "bmm.out";
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "bmm.out";
 
   auto in_type = in.scalar_type();
 
   if (executorch::runtime::isComplexType(in_type)) {
-    ET_SWITCH_COMPLEXH_TYPES(in_type, ctx, name, CTYPE, [&]() {
+    ET_SWITCH_COMPLEXH_TYPES(in_type, ctx, op_name, CTYPE, [&]() {
       internal::bmm_out_impl<CTYPE>(in, mat2, out);
     });
   } else {
-    ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
+    ET_SWITCH_REALH_TYPES(in_type, ctx, op_name, CTYPE, [&]() {
       internal::bmm_out_impl<CTYPE>(in, mat2, out);
     });
   }
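Both branches call the same templated `bmm_out_impl`, so the dispatch only selects the element type; the same loop nest works for real and complex elements alike. A naive float reference for what bmm computes, assuming row-major `in: [B, M, K]`, `mat2: [B, K, N]`, `out: [B, M, N]` (not the actual `internal::bmm_out_impl`):

```cpp
void bmm_ref(const float* in, const float* mat2, float* out,
             int B, int M, int K, int N) {
  for (int b = 0; b < B; ++b) {
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += in[(b * M + m) * K + k] * mat2[(b * K + k) * N + n];
        }
        out[(b * M + m) * N + n] = acc;  // independent matmul per batch b
      }
    }
  }
}
```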
9 changes: 6 additions & 3 deletions kernels/portable/cpu/op_cat.cpp
@@ -59,14 +59,17 @@ Tensor& cat_out(
   const bool out_is_complex =
       executorch::runtime::isComplexType(out.scalar_type());
 
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "cat.out";
+
   if (out_is_complex) {
     // TODO: The current support for complex dtype enforces that input and
     // output tensors have the same dtype. Support mixed dtypes in the future.
     for (size_t i = 0; i < ninputs; ++i) {
       const auto in_type = tensors[i].scalar_type();
       ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
     }
-    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&] {
       CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
       for (size_t i = 0; i < outer; ++i) {
         for (size_t j = 0; j < ninputs; ++j) {
@@ -82,12 +85,12 @@
       }
     });
   } else {
-    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, op_name, CTYPE_OUT, [&] {
       CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
       for (size_t i = 0; i < outer; ++i) {
         for (size_t j = 0; j < ninputs; ++j) {
           const auto in_type = tensors[j].scalar_type();
-          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE_IN, [&] {
             if (tensors[j].numel() == 0) {
               return;
             }
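The nested switch in the non-complex path exists because each input of cat may have its own dtype: the outer switch fixes the output type once, and the inner, per-input switch casts each chunk as it is appended (the complex path instead requires all dtypes to match, per the `ET_KERNEL_CHECK`). A sketch of that innermost copy (a stand-in, not the kernel's actual loop body):

```cpp
#include <cstddef>

// Append one input chunk to the concatenated output, casting element-wise
// from the chunk's dtype to the output dtype.
template <typename CTYPE_IN, typename CTYPE_OUT>
CTYPE_OUT* append_cast(const CTYPE_IN* in, CTYPE_OUT* out, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = static_cast<CTYPE_OUT>(in[i]);
  }
  return out + n;  // next write position in the output
}
```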
6 changes: 4 additions & 2 deletions kernels/portable/cpu/op_cdist_forward.cpp
@@ -160,10 +160,12 @@ Tensor& _cdist_forward_out(
       out);
 
   ScalarType out_type = out.scalar_type();
-  constexpr auto name = "_cdist_forward.out";
 
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "_cdist_forward.out";
+
   ET_SWITCH_FLOATHBF16_TYPES(
-      out_type, ctx, name, CTYPE, [&] { cdist<CTYPE>(x1, x2, out, p); });
+      out_type, ctx, op_name, CTYPE, [&] { cdist<CTYPE>(x1, x2, out, p); });
 
   return out;
 }
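For reference, the pairwise p-norm distance that cdist computes, sketched for the unbatched case with row-major `x1: [P, M]`, `x2: [R, M]`, `out: [P, R]` and finite `p > 0` (the real `cdist<CTYPE>` also handles batching and the `p = 0` and `p = inf` special cases):

```cpp
#include <cmath>

// out[i][j] = (sum_k |x1[i][k] - x2[j][k]|^p)^(1/p)
void cdist_ref(const float* x1, const float* x2, float* out,
               int P, int R, int M, double p) {
  for (int i = 0; i < P; ++i) {
    for (int j = 0; j < R; ++j) {
      double acc = 0.0;
      for (int k = 0; k < M; ++k) {
        const double d = static_cast<double>(x1[i * M + k]) -
                         static_cast<double>(x2[j * M + k]);
        acc += std::pow(std::fabs(d), p);
      }
      out[i * R + j] = static_cast<float>(std::pow(acc, 1.0 / p));
    }
  }
}
```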
7 changes: 5 additions & 2 deletions kernels/portable/cpu/op_clamp.cpp
@@ -40,16 +40,19 @@ ET_NODISCARD bool check_bounds(
     const char* val_name) {
   auto is_valid = true;
 
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "clamp.out";
+
   if (isIntegralType(out_type, /*includeBool=*/false)) {
     const long val_long = utils::scalar_to<long>(val_scalar);
-    ET_SWITCH_INT_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
+    ET_SWITCH_INT_TYPES(out_type, ctx, op_name, CTYPE_OUT, [&]() {
       if (is_out_of_bounds<CTYPE_OUT, long>(val_long)) {
         ET_LOG(Error, "%s value out of bounds", val_name);
         is_valid = false;
       }
     });
   } else if (isFloatingType(out_type)) {
-    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, "clamp.out", CTYPE_OUT, [&]() {
+    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, op_name, CTYPE_OUT, [&]() {
       const double val_double = utils::scalar_to<double>(val_scalar);
       if (std::isfinite(val_double) &&
           is_out_of_bounds<CTYPE_OUT, double>(val_double)) {
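The idea behind `check_bounds`: a clamp min/max given as a Scalar must be representable in the output dtype, otherwise the kernel reports the argument as invalid. A sketch of what an `is_out_of_bounds`-style range check amounts to (the actual helper's implementation may differ):

```cpp
#include <cstdint>
#include <limits>

// True if val cannot be represented in CTYPE_OUT's finite range.
template <typename CTYPE_OUT, typename VAL>
bool out_of_bounds_sketch(VAL val) {
  return val < static_cast<VAL>(std::numeric_limits<CTYPE_OUT>::lowest()) ||
         val > static_cast<VAL>(std::numeric_limits<CTYPE_OUT>::max());
}

// e.g. out_of_bounds_sketch<int8_t, long>(300) == true, so clamping an
// int8 output tensor with a bound of 300 would be rejected.
```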
5 changes: 4 additions & 1 deletion kernels/portable/cpu/op_constant_pad_nd.cpp
@@ -184,7 +184,10 @@ Tensor& constant_pad_nd_out(
 
   ScalarType in_type = in.scalar_type();
 
-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "constant_pad_nd.out";
+
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE, [&]() {
     auto opt_value_casted =
         utils::internal::check_overflow_scalar_cast<CTYPE>(value);
     ET_KERNEL_CHECK(ctx, opt_value_casted.has_value(), InvalidArgument, );
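A 1-D sketch of constant_pad_nd semantics: positions outside the input are filled with the pad value, which is why the kernel first verifies (via `check_overflow_scalar_cast`) that the value fits the element dtype. The n-dimensional kernel applies the same idea per dimension:

```cpp
#include <cstddef>
#include <vector>

std::vector<float> pad1d(
    const std::vector<float>& in, size_t left, size_t right, float value) {
  // Start with every position holding the pad value, then copy the input
  // into the interior.
  std::vector<float> out(left + in.size() + right, value);
  for (size_t i = 0; i < in.size(); ++i) {
    out[left + i] = in[i];
  }
  return out;
}
// pad1d({1, 2, 3}, 2, 1, 0.0f) -> {0, 0, 1, 2, 3, 0}
```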