5 changes: 4 additions & 1 deletion kernels/optimized/CMakeLists.txt
@@ -21,7 +21,10 @@ if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

set(_common_compile_options -Wno-deprecated-declarations)
set(_common_compile_options
$<$<CXX_COMPILER_ID:MSVC>:/wd4996>
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wno-deprecated-declarations>
)

# Note for apple platform we can rely on Accelerate framework Will come back to
# this
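Both CMakeLists changes in this PR swap the hard-coded GCC/Clang flag for a generator expression, so MSVC gets `/wd4996` (its ID for deprecation warnings) while every other compiler keeps `-Wno-deprecated-declarations`. A minimal sketch of the C++ warning both flags suppress — the function name and message here are illustrative, not from the PR:

```cpp
// Illustrative only: any call to a [[deprecated]] entity triggers warning
// C4996 under MSVC and -Wdeprecated-declarations under GCC/Clang.
[[deprecated("use new_api() instead")]] int old_api() { return 0; }

int main() {
  return old_api(); // warns unless /wd4996 or -Wno-deprecated-declarations is set
}
```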
5 changes: 2 additions & 3 deletions kernels/optimized/cpu/op_bmm.cpp
@@ -150,15 +150,14 @@ Tensor& opt_bmm_out(
ET_KERNEL_CHECK(
ctx, check_bmm_out_args(self, mat2, out), InvalidArgument, out);

constexpr auto name = "bmm.out";
auto self_type = self.scalar_type();

if (executorch::runtime::isComplexType(self_type)) {
ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() {
ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, "bmm.out", CTYPE, [&]() {
bmm_kernel<CTYPE>(self, mat2, out);
});
} else {
ET_SWITCH_REALHBF16_TYPES(self_type, ctx, name, CTYPE, [&]() {
ET_SWITCH_REALHBF16_TYPES(self_type, ctx, "bmm.out", CTYPE, [&]() {
bmm_kernel<CTYPE>(self, mat2, out);
});
}
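The op_bmm change (and the matching edits in op_masked_scatter.cpp and op_view_as_real_copy.cpp below) inlines the operator-name string literal at each ET_SWITCH_* call site instead of routing it through a `constexpr auto name` local. The PR does not state the motivation, but a plausible reading is MSVC's long-standing trouble with constexpr locals referenced inside the lambdas those macros generate. A reduced sketch of the two patterns, with a hypothetical `dispatch` standing in for the macro expansion:

```cpp
#include <cstdio>

// Stand-in for the lambda the ET_SWITCH_* macros wrap around the body.
template <typename F>
void dispatch(F f) { f(); }

void run() {
  // Pattern the PR removes: a constexpr local used inside the nested lambda,
  // which some MSVC versions have been known to reject or mishandle.
  constexpr auto name = "bmm.out";
  dispatch([&] { std::printf("%s\n", name); });

  // Pattern the PR adopts: the literal appears directly at the call site.
  dispatch([&] { std::printf("%s\n", "bmm.out"); });
}
```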
4 changes: 1 addition & 3 deletions kernels/portable/cpu/op_masked_scatter.cpp
@@ -41,13 +41,11 @@ Tensor& masked_scatter_out(
InvalidArgument,
out);

constexpr auto op_name = "masked_scatter.out";

int64_t idx = 0;
int64_t src_numel = src.numel();
bool src_numel_check = true;

ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE, [&]() {
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "masked_scatter.out", CTYPE, [&]() {
const CTYPE* const src_data = src.const_data_ptr<CTYPE>();
apply_binary_elementwise_fn<CTYPE, bool, CTYPE>(
[src_data, &idx, &src_numel, &src_numel_check](
43 changes: 27 additions & 16 deletions kernels/portable/cpu/op_topk.cpp
@@ -6,14 +6,14 @@
* LICENSE file in the root directory of this source tree.
*/

#include <c10/util/irange.h>
#include <cmath>
#include <tuple>

#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <c10/util/irange.h>
#include <cmath>
#include <tuple>

namespace torch {
namespace executor {
namespace native {
@@ -118,19 +118,30 @@ void perform_topk(
}

// Perform topk on the queue
const auto elem_greater = [](const elem_t& x, const elem_t& y) -> bool {
return float_less_than(y.first, x.first);
};
const auto elem_less = [](const elem_t& x, const elem_t& y) -> bool {
return float_less_than(x.first, y.first);
};
const auto cmp = largest ? elem_greater : elem_less;
if (use_partial_sort) {
std::partial_sort(queue, queue + k, queue + dim_size, cmp);
if (largest) {
const auto elem_greater = [](const elem_t& x, const elem_t& y) -> bool {
return float_less_than(y.first, x.first);
};
if (use_partial_sort) {
std::partial_sort(queue, queue + k, queue + dim_size, elem_greater);
} else {
std::nth_element(
queue, queue + k - 1, queue + dim_size, elem_greater);
if (sorted) {
std::sort(queue, queue + k - 1, elem_greater);
}
}
} else {
std::nth_element(queue, queue + k - 1, queue + dim_size, cmp);
if (sorted) {
std::sort(queue, queue + k - 1, cmp);
const auto elem_less = [](const elem_t& x, const elem_t& y) -> bool {
return float_less_than(x.first, y.first);
};
if (use_partial_sort) {
std::partial_sort(queue, queue + k, queue + dim_size, elem_less);
} else {
std::nth_element(queue, queue + k - 1, queue + dim_size, elem_less);
if (sorted) {
std::sort(queue, queue + k - 1, elem_less);
}
}
}

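The topk rewrite trades the comparator-selecting ternary for an explicit `if (largest)` branch that duplicates the sort calls. The duplication isn't free, but `elem_greater` and `elem_less` are distinct closure types, so `largest ? elem_greater : elem_less` can only compile by decaying both lambdas to a function pointer, which blocks inlining and is a recurring portability sore spot. A reduced sketch of the two shapes (container and comparators assumed, not from the PR):

```cpp
#include <algorithm>
#include <vector>

// Before: the ternary can only pick between two distinct closure types by
// decaying both to a plain function pointer ('+' makes the decay explicit).
void sort_before(std::vector<int>& v, bool descending) {
  const auto greater = [](int a, int b) { return b < a; };
  const auto less = [](int a, int b) { return a < b; };
  bool (*cmp)(int, int) = descending ? +greater : +less;
  std::sort(v.begin(), v.end(), cmp);
}

// After: one call site per branch; each comparator keeps its closure type,
// so the compiler can inline it into the sort.
void sort_after(std::vector<int>& v, bool descending) {
  if (descending) {
    std::sort(v.begin(), v.end(), [](int a, int b) { return b < a; });
  } else {
    std::sort(v.begin(), v.end(), [](int a, int b) { return a < b; });
  }
}
```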
14 changes: 7 additions & 7 deletions kernels/portable/cpu/op_view_as_real_copy.cpp
@@ -64,13 +64,13 @@ Tensor& view_as_real_copy_out(
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);

constexpr auto op_name = "view_as_real_copy.out";
Contributor Author: oh missed one.

ET_SWITCH_COMPLEXH_TYPES(self.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
});
});
ET_SWITCH_COMPLEXH_TYPES(
self.scalar_type(), ctx, "view_as_real_copy.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATH_TYPES(
out.scalar_type(), ctx, "view_as_real_copy.out", CTYPE_OUT, [&] {
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
});
});

return out;
}
5 changes: 4 additions & 1 deletion kernels/portable/cpu/util/CMakeLists.txt
@@ -21,7 +21,10 @@ endif()

list(TRANSFORM _kernels_util_all_deps__srcs PREPEND "${EXECUTORCH_ROOT}/")

set(_common_compile_options -Wno-deprecated-declarations)
set(_common_compile_options
$<$<CXX_COMPILER_ID:MSVC>:/wd4996>
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wno-deprecated-declarations>
)

add_library(kernels_util_all_deps ${_kernels_util_all_deps__srcs})
target_link_libraries(kernels_util_all_deps PRIVATE executorch_core)
48 changes: 23 additions & 25 deletions kernels/portable/cpu/util/elementwise_util.h
@@ -85,7 +85,6 @@ inline void dtype_specialized_elementwise_fn_impl(
static_assert(
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
...));
constexpr auto kNumInputs = sizeof...(inputs);
// All inputs must be of type CTYPE_COMPUTE.
ET_DCHECK(
((inputs.first->scalar_type() ==
@@ -105,8 +104,9 @@
out.numel(),
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
std::array<const CTYPE_COMPUTE*, sizeof...(inputs)>
inputs_data_ptrs = {
inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};

CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

@@ -119,11 +119,11 @@
// small-sized tests will test whether using Vectorized broke our
// lambda.
#ifndef NDEBUG
std::array<Vec, kNumInputs> loaded_inputs{};
std::array<Vec, sizeof...(inputs)> loaded_inputs{};
#else // NDEBUG
std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
#endif // NDEBUG
for (const auto input_idx : c10::irange(kNumInputs)) {
for (const auto input_idx : c10::irange(sizeof...(inputs))) {
loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
}
#ifndef NDEBUG
@@ -136,8 +136,8 @@
// Main vectorized loop.
for (auto idx = vectorized_begin; idx < vectorized_end;
idx += Vec::size()) {
std::array<Vec, kNumInputs> loaded_vec_inputs{};
for (const auto input_idx : c10::irange(kNumInputs)) {
std::array<Vec, sizeof...(inputs)> loaded_vec_inputs{};
for (const auto input_idx : c10::irange(sizeof...(inputs))) {
loaded_vec_inputs[input_idx] =
Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
}
@@ -148,11 +148,11 @@
// Scalar epilogue.
for (const auto idx : c10::irange(vectorized_end, end)) {
#ifndef NDEBUG
std::array<Vec, kNumInputs> loaded_inputs{};
std::array<Vec, sizeof...(inputs)> loaded_inputs{};
#else // NDEBUG
std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
#endif // NDEBUG
for (const auto input_idx : c10::irange(kNumInputs)) {
for (const auto input_idx : c10::irange(sizeof...(inputs))) {
loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
}
#ifndef NDEBUG
@@ -172,20 +172,20 @@
out.numel(),
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
std::array<const CTYPE_COMPUTE*, sizeof...(inputs)> inputs_data_ptrs = {
inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};

CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

const auto range =
BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
out, (*inputs.first)...);
const auto range = BroadcastIndexesRange<
sizeof...(inputs),
support_noncontiguous_tensors>(out, (*inputs.first)...);
auto begin_it = range.begin();
begin_it += begin;
for (; (*begin_it)[0] < end; ++begin_it) {
const auto& indexes = *begin_it;
std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
for (const auto idx : c10::irange(kNumInputs)) {
std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
for (const auto idx : c10::irange(sizeof...(inputs))) {
loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
}
data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);
@@ -229,14 +229,12 @@ inline void apply_elementwise_fn_generic_impl(
const Tensor& out,
SupportedTensorDtypes out_dtypes,
Args... inputs) {
constexpr auto kNumInputs = sizeof...(inputs);

struct InputInfo {
load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
const char* data_ptr;
ssize_t element_size;
};
std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
std::array<InputInfo, sizeof...(inputs)> inputs_info = {(InputInfo{
internal::get_load_to_compute_fn<CTYPE_COMPUTE, op_name>(
ctx, *inputs.first, inputs.second),
reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
@@ -254,15 +252,15 @@
out.numel(),
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
const auto range =
BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
out, (*inputs.first)...);
const auto range = BroadcastIndexesRange<
sizeof...(inputs),
support_noncontiguous_tensors>(out, (*inputs.first)...);
auto begin_it = range.begin();
begin_it += begin;
for (; (*begin_it)[0] < end; ++begin_it) {
const auto& indexes = *begin_it;
std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
for (const auto idx : c10::irange(kNumInputs)) {
std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
for (const auto idx : c10::irange(sizeof...(inputs))) {
const auto& input_info = inputs_info[idx];
loaded_inputs[idx] = input_info.load_to_compute(
&input_info
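The elementwise_util.h edits all make one substitution: the `constexpr auto kNumInputs = sizeof...(inputs);` local is gone, and `sizeof...(inputs)` is spelled out at every use (array sizes, `c10::irange`, the `BroadcastIndexesRange` template arguments). The PR doesn't say why, but the uses all sit inside the lambdas handed to `parallel_for`, and some MSVC versions refuse to treat a constexpr local derived from an enclosing function's parameter pack as a constant expression there, whereas `sizeof...` itself is always usable since it never odr-uses the pack. A reduced sketch under that assumption (function and names hypothetical):

```cpp
#include <array>
#include <cstddef>

template <typename... Args>
std::size_t count_inside_lambda(Args... inputs) {
  // Pattern the PR removes:
  //   constexpr auto kNumInputs = sizeof...(inputs);
  //   auto f = [&] { return std::array<int, kNumInputs>{}.size(); };
  // Pattern the PR adopts: sizeof...(inputs) directly at each use.
  auto f = [&] { return std::array<int, sizeof...(inputs)>{}.size(); };
  return f();
}

// e.g. count_inside_lambda(1, 2.0, 'x') == 3
```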