Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions kernels/portable/cpu/op_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <cstring>

#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
Expand All @@ -32,15 +33,17 @@ namespace {
* in_C_per_group x in_H x in_W, to compute an out channel of size 1 x out_H x
* out_W.
*/
template <typename CTYPE, typename CTYPE_BIAS>
template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
void conv2d_impl(
const CTYPE* const in_ptr,
SizesArrayRef in_sizes,
StridesArrayRef in_strides,
const CTYPE* const w_ptr,
SizesArrayRef w_sizes,
StridesArrayRef w_strides,
const CTYPE_BIAS* const bias_ptr,
const exec_aten::optional<Tensor>& bias,
const char* const bias_ptr,
LoadFn load_bias,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
Expand Down Expand Up @@ -128,7 +131,7 @@ void conv2d_impl(
}

if (bias_ptr != nullptr) {
accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
accum += load_bias(&bias_ptr[out_c * bias.value().element_size()]);
}
size_t out_idx =
calculate_linear_index(out_coord, out_strides.data(), 4);
Expand Down Expand Up @@ -185,11 +188,12 @@ void conv2d_impl(
}
}

template <typename CTYPE, typename CTYPE_BIAS>
template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
void convolution_wrapper(
const Tensor& in,
const Tensor& weight,
const exec_aten::optional<Tensor>& bias,
LoadFn load_bias,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
Expand Down Expand Up @@ -280,8 +284,9 @@ void convolution_wrapper(
CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
const CTYPE* const w_ptr = weight.const_data_ptr<CTYPE>();
const CTYPE_BIAS* const bias_ptr =
bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
const char* const bias_ptr = bias.has_value()
? reinterpret_cast<const char*>(bias.value().const_data_ptr())
: nullptr;

size_t out_N = out.size(0);
size_t out_C = out.size(1);
Expand All @@ -296,8 +301,9 @@ void convolution_wrapper(
} else {
// If bias is present, we initialize the output to the bias value
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(
bias_ptr[(out_ix / out_strides[1]) % out_C]);
out_ptr[out_ix] = load_bias(&bias_ptr
[((out_ix / out_strides[1]) % out_C) *
bias.value().element_size()]);
}
}
}
Expand All @@ -316,7 +322,9 @@ void convolution_wrapper(
w_ptr,
weight_sizes,
{weight_strides, 4},
bias,
bias_ptr,
load_bias,
stride_,
padding_,
dilation_,
Expand Down Expand Up @@ -398,19 +406,25 @@ Tensor& convolution_out(
return out;
}

ScalarType in_type = in.scalar_type();
ScalarType bias_type = in_type;
if (bias.has_value()) {
bias_type = bias.value().scalar_type();
}

constexpr auto name = "convolution.out";

ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
convolution_wrapper<CTYPE, CTYPE_BIAS>(
in, weight, bias, stride, padding, dilation, transposed, groups, out);
});
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char name[] = "convolution.out";

ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
const auto load_bias = bias.has_value()
? utils::internal::get_load_to_common_fn<CTYPE, name>(
bias.value(), utils::SupportedTensorDtypes::REALHBF16)
: nullptr;
convolution_wrapper<CTYPE>(
in,
weight,
bias,
load_bias,
stride,
padding,
dilation,
transposed,
groups,
out);
});

return out;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ ATEN_OPS = (
op_target(
name = "op_convolution",
deps = [
"//executorch/kernels/portable/cpu/util:dtype_util",
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
":vec_ops",
],
Expand Down
Loading