Commit 904b2b4

more statics and explicit typing for lambdas

1 parent: 83cf868

4 files changed (+47, -49 lines)

kernels/optimized/cpu/op_bmm.cpp (4 additions, 2 deletions)

@@ -152,12 +152,14 @@ Tensor& opt_bmm_out(
 
   auto self_type = self.scalar_type();
 
+  static constexpr auto name = "bmm.out";
+
   if (executorch::runtime::isComplexType(self_type)) {
-    ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, "bmm.out", CTYPE, [&]() {
+    ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() {
       bmm_kernel<CTYPE>(self, mat2, out);
     });
   } else {
-    ET_SWITCH_REALHBF16_TYPES(self_type, ctx, "bmm.out", CTYPE, [&]() {
+    ET_SWITCH_REALHBF16_TYPES(self_type, ctx, name, CTYPE, [&]() {
       bmm_kernel<CTYPE>(self, mat2, out);
     });
   }
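Note on the pattern: the ET_SWITCH_* macros expand into lambdas, and a local with static storage duration can be read inside a lambda without being captured, so a single static constexpr name both deduplicates the "bmm.out" literal and stays visible in every switch body. A minimal standalone sketch of that language rule (dispatch_sketch and the printf logging are illustrative, not the ExecuTorch macros):

#include <cstdio>

// A static constexpr local has static storage duration, so lambdas in the
// enclosing scope may read it without capturing it.
void dispatch_sketch(bool is_complex) {
  static constexpr auto name = "bmm.out";

  // Both lambdas are captureless, yet both can refer to name.
  auto complex_case = [] { std::printf("%s: complex path\n", name); };
  auto real_case = [] { std::printf("%s: real path\n", name); };

  if (is_complex) {
    complex_case();
  } else {
    real_case();
  }
}

int main() {
  dispatch_sketch(true);
  dispatch_sketch(false);
}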

kernels/portable/cpu/op_masked_scatter.cpp (3 additions, 1 deletion)

@@ -45,7 +45,9 @@ Tensor& masked_scatter_out(
   int64_t src_numel = src.numel();
   bool src_numel_check = true;
 
-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "masked_scatter.out", CTYPE, [&]() {
+  static constexpr auto name = "masked_scatter.out";
+
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE, [&]() {
     const CTYPE* const src_data = src.const_data_ptr<CTYPE>();
     apply_binary_elementwise_fn<CTYPE, bool, CTYPE>(
         [src_data, &idx, &src_numel, &src_numel_check](
kernels/portable/cpu/op_topk.cpp (15 additions, 23 deletions)

@@ -118,30 +118,22 @@ void perform_topk(
   }
 
   // Perform topk on the queue
-  if (largest) {
-    const auto elem_greater = [](const elem_t& x, const elem_t& y) -> bool {
-      return float_less_than(y.first, x.first);
-    };
-    if (use_partial_sort) {
-      std::partial_sort(queue, queue + k, queue + dim_size, elem_greater);
-    } else {
-      std::nth_element(
-          queue, queue + k - 1, queue + dim_size, elem_greater);
-      if (sorted) {
-        std::sort(queue, queue + k - 1, elem_greater);
-      }
-    }
+  bool (*elem_greater)(const elem_t&, const elem_t&) =
+      [](const elem_t& x, const elem_t& y) -> bool {
+    return float_less_than(y.first, x.first);
+  };
+  bool (*elem_less)(const elem_t&, const elem_t&) =
+      [](const elem_t& x, const elem_t& y) -> bool {
+    return float_less_than(x.first, y.first);
+  };
+  bool (*cmp)(const elem_t&, const elem_t&) =
+      largest ? elem_greater : elem_less;
+  if (use_partial_sort) {
+    std::partial_sort(queue, queue + k, queue + dim_size, cmp);
   } else {
-    const auto elem_less = [](const elem_t& x, const elem_t& y) -> bool {
-      return float_less_than(x.first, y.first);
-    };
-    if (use_partial_sort) {
-      std::partial_sort(queue, queue + k, queue + dim_size, elem_less);
-    } else {
-      std::nth_element(queue, queue + k - 1, queue + dim_size, elem_less);
-      if (sorted) {
-        std::sort(queue, queue + k - 1, elem_less);
-      }
+    std::nth_element(queue, queue + k - 1, queue + dim_size, cmp);
+    if (sorted) {
+      std::sort(queue, queue + k - 1, cmp);
     }
   }
 
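This rewrite hinges on a conversion rule: a lambda with an empty capture list converts to an ordinary function pointer, so both comparators can share the type bool (*)(const elem_t&, const elem_t&) and a runtime ternary can select one, collapsing the duplicated largest/smallest branches into a single sort path. A standalone sketch of the same shape (topk_sketch is illustrative, and plain operator< stands in for the file's float_less_than helper):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using elem_t = std::pair<float, int64_t>;  // (value, original index)

// Captureless lambdas decay to function pointers, so one cmp variable can
// hold either comparator and be handed to the sort algorithm.
void topk_sketch(std::vector<elem_t>& queue, std::size_t k, bool largest) {
  bool (*elem_greater)(const elem_t&, const elem_t&) =
      [](const elem_t& x, const elem_t& y) -> bool { return y.first < x.first; };
  bool (*elem_less)(const elem_t&, const elem_t&) =
      [](const elem_t& x, const elem_t& y) -> bool { return x.first < y.first; };
  bool (*cmp)(const elem_t&, const elem_t&) = largest ? elem_greater : elem_less;

  // Move the k best elements, sorted, to the front (assumes k <= queue.size()).
  std::partial_sort(queue.begin(), queue.begin() + k, queue.end(), cmp);
}

One trade-off to note: calls through a function pointer can be harder for the compiler to inline than two separately typed lambdas, in exchange for a single copy of the sort logic.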

kernels/portable/cpu/util/elementwise_util.h (25 additions, 23 deletions)

@@ -85,6 +85,7 @@ inline void dtype_specialized_elementwise_fn_impl(
   static_assert(
       (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
        ...));
+  static constexpr auto kNumInputs = sizeof...(inputs);
   // All inputs must be of type CTYPE_COMPUTE.
   ET_DCHECK(
       ((inputs.first->scalar_type() ==

@@ -104,9 +105,8 @@ inline void dtype_specialized_elementwise_fn_impl(
       out.numel(),
       ::executorch::extension::internal::GRAIN_SIZE,
       [&](const auto begin, const auto end) {
-        std::array<const CTYPE_COMPUTE*, sizeof...(inputs)>
-            inputs_data_ptrs = {
-                inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+            inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
 
         CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 

@@ -119,11 +119,11 @@ inline void dtype_specialized_elementwise_fn_impl(
         // small-sized tests will test whether using Vectorized broke our
         // lambda.
 #ifndef NDEBUG
-        std::array<Vec, sizeof...(inputs)> loaded_inputs{};
+        std::array<Vec, kNumInputs> loaded_inputs{};
 #else // NDEBUG
-        std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
+        std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
 #endif // NDEBUG
-        for (const auto input_idx : c10::irange(sizeof...(inputs))) {
+        for (const auto input_idx : c10::irange(kNumInputs)) {
           loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
         }
 #ifndef NDEBUG

@@ -136,8 +136,8 @@ inline void dtype_specialized_elementwise_fn_impl(
         // Main vectorized loop.
         for (auto idx = vectorized_begin; idx < vectorized_end;
              idx += Vec::size()) {
-          std::array<Vec, sizeof...(inputs)> loaded_vec_inputs{};
-          for (const auto input_idx : c10::irange(sizeof...(inputs))) {
+          std::array<Vec, kNumInputs> loaded_vec_inputs{};
+          for (const auto input_idx : c10::irange(kNumInputs)) {
             loaded_vec_inputs[input_idx] =
                 Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
           }

@@ -148,11 +148,11 @@ inline void dtype_specialized_elementwise_fn_impl(
         // Scalar epilogue.
         for (const auto idx : c10::irange(vectorized_end, end)) {
 #ifndef NDEBUG
-          std::array<Vec, sizeof...(inputs)> loaded_inputs{};
+          std::array<Vec, kNumInputs> loaded_inputs{};
 #else // NDEBUG
-          std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
 #endif // NDEBUG
-          for (const auto input_idx : c10::irange(sizeof...(inputs))) {
+          for (const auto input_idx : c10::irange(kNumInputs)) {
             loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
           }
 #ifndef NDEBUG

@@ -172,20 +172,20 @@ inline void dtype_specialized_elementwise_fn_impl(
       out.numel(),
       ::executorch::extension::internal::GRAIN_SIZE,
       [&](const auto begin, const auto end) {
-        std::array<const CTYPE_COMPUTE*, sizeof...(inputs)> inputs_data_ptrs = {
+        std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
             inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
 
         CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-        const auto range = BroadcastIndexesRange<
-            sizeof...(inputs),
-            support_noncontiguous_tensors>(out, (*inputs.first)...);
+        const auto range =
+            BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
+                out, (*inputs.first)...);
         auto begin_it = range.begin();
         begin_it += begin;
         for (; (*begin_it)[0] < end; ++begin_it) {
           const auto& indexes = *begin_it;
-          std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
-          for (const auto idx : c10::irange(sizeof...(inputs))) {
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
+          for (const auto idx : c10::irange(kNumInputs)) {
             loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]];
           }
           data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs);

@@ -229,12 +229,14 @@ inline void apply_elementwise_fn_generic_impl(
     const Tensor& out,
     SupportedTensorDtypes out_dtypes,
     Args... inputs) {
+  static constexpr auto kNumInputs = sizeof...(inputs);
+
   struct InputInfo {
     load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
     const char* data_ptr;
     ssize_t element_size;
   };
-  std::array<InputInfo, sizeof...(inputs)> inputs_info = {(InputInfo{
+  std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
      internal::get_load_to_compute_fn<CTYPE_COMPUTE, op_name>(
          ctx, *inputs.first, inputs.second),
      reinterpret_cast<const char*>(inputs.first->const_data_ptr()),

@@ -252,15 +254,15 @@ inline void apply_elementwise_fn_generic_impl(
       out.numel(),
       ::executorch::extension::internal::GRAIN_SIZE,
       [&](const auto begin, const auto end) {
-        const auto range = BroadcastIndexesRange<
-            sizeof...(inputs),
-            support_noncontiguous_tensors>(out, (*inputs.first)...);
+        const auto range =
+            BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
+                out, (*inputs.first)...);
         auto begin_it = range.begin();
         begin_it += begin;
         for (; (*begin_it)[0] < end; ++begin_it) {
           const auto& indexes = *begin_it;
-          std::array<CTYPE_COMPUTE, sizeof...(inputs)> loaded_inputs{};
-          for (const auto idx : c10::irange(sizeof...(inputs))) {
+          std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs{};
+          for (const auto idx : c10::irange(kNumInputs)) {
             const auto& input_info = inputs_info[idx];
             loaded_inputs[idx] = input_info.load_to_compute(
                 &input_info
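The kNumInputs change is mechanical but worth naming as a pattern: bind the parameter-pack size to a static constexpr once, then reuse it for every array bound and loop count rather than repeating sizeof...(inputs) at each use site. A minimal standalone sketch (sum_inputs is a hypothetical variadic function, not part of this header):

#include <array>
#include <cstddef>

// Name the pack size once; every array bound and loop below reuses it.
template <typename... Args>
double sum_inputs(Args... inputs) {
  static constexpr auto kNumInputs = sizeof...(inputs);
  std::array<double, kNumInputs> values = {static_cast<double>(inputs)...};
  double acc = 0.0;
  for (std::size_t i = 0; i < kNumInputs; ++i) {
    acc += values[i];
  }
  return acc;  // e.g. sum_inputs(1, 2.5f, 3.0) == 6.5
}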
