
Commit b8da6f6

WIP: outline mixed dtype kernel
ghstack-source-id: 2c0dc4e
ghstack-comment-id: 3033889715
Pull-Request-resolved: #12219
1 parent: be9693a

File tree

1 file changed: +146 -20 lines
kernels/portable/cpu/util/elementwise_util.h

Lines changed: 146 additions & 20 deletions
@@ -217,35 +217,27 @@ inline bool validate_elementwise_fn_inputs(
   return true;
 }
 
+template <typename CTYPE_COMPUTE>
+struct InputInfo {
+  load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
+  const char* data_ptr;
+  ssize_t element_size;
+};
+
 template <
     typename CTYPE_COMPUTE,
-    const char* op_name,
     bool support_noncontiguous_tensors,
+    size_t kNumInputs,
     typename Op,
     typename... Args>
-inline void apply_elementwise_fn_generic_impl(
+inline void apply_elementwise_fn_generic_impl_with_load_store_functions(
+    const std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs>& inputs_info,
+    store_compute_to_tensor_fn<CTYPE_COMPUTE> store_compute_to_out,
     const Op& compute_fun,
     KernelRuntimeContext& ctx,
     const Tensor& out,
-    SupportedTensorDtypes out_dtypes,
     Args... inputs) {
-  constexpr auto kNumInputs = sizeof...(inputs);
-
-  struct InputInfo {
-    load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
-    const char* data_ptr;
-    ssize_t element_size;
-  };
-  std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
-      internal::get_load_to_compute_fn<CTYPE_COMPUTE, op_name>(
-          *inputs.first, inputs.second),
-      reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
-      inputs.first->element_size(),
-  })...};
-
-  const auto store_compute_to_out =
-      internal::get_store_compute_to_tensor_fn<CTYPE_COMPUTE, op_name>(
-          out, out_dtypes);
+  static_assert(kNumInputs == sizeof...(inputs));
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   const auto out_element_size = out.element_size();
 
@@ -275,6 +267,140 @@ inline void apply_elementwise_fn_generic_impl(
   });
 }
 
+template <
+    typename CTYPE_COMPUTE,
+    bool support_noncontiguous_tensors,
+    size_t kNumInputs,
+    typename... Args>
+ET_NOINLINE void
+apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op(
+    const std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs>& inputs_info,
+    store_compute_to_tensor_fn<CTYPE_COMPUTE> store_compute_to_out,
+    CTYPE_COMPUTE (*compute_fun)(CTYPE_COMPUTE),
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  apply_elementwise_fn_generic_impl_with_load_store_functions<
+      CTYPE_COMPUTE,
+      support_noncontiguous_tensors,
+      kNumInputs>(
+      inputs_info, store_compute_to_out, compute_fun, ctx, out, inputs...);
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    bool support_noncontiguous_tensors,
+    size_t kNumInputs,
+    typename... Args>
+ET_NOINLINE void
+apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op(
+    const std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs>& inputs_info,
+    store_compute_to_tensor_fn<CTYPE_COMPUTE> store_compute_to_out,
+    CTYPE_COMPUTE (*compute_fun)(CTYPE_COMPUTE, CTYPE_COMPUTE),
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  apply_elementwise_fn_generic_impl_with_load_store_functions<
+      CTYPE_COMPUTE,
+      support_noncontiguous_tensors,
+      kNumInputs>(
+      inputs_info, store_compute_to_out, compute_fun, ctx, out, inputs...);
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    bool support_noncontiguous_tensors,
+    size_t kNumInputs,
+    typename... Args>
+ET_NOINLINE void
+apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op(
+    const std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs>& inputs_info,
+    store_compute_to_tensor_fn<CTYPE_COMPUTE> store_compute_to_out,
+    CTYPE_COMPUTE (*compute_fun)(CTYPE_COMPUTE, CTYPE_COMPUTE, CTYPE_COMPUTE),
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    Args... inputs) {
+  apply_elementwise_fn_generic_impl_with_load_store_functions<
+      CTYPE_COMPUTE,
+      support_noncontiguous_tensors,
+      kNumInputs>(
+      inputs_info, store_compute_to_out, compute_fun, ctx, out, inputs...);
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    bool support_noncontiguous_tensors,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn_generic_impl(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes,
+    Args... inputs) {
+  constexpr auto kNumInputs = sizeof...(inputs);
+
+  std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs> inputs_info = {
+      (InputInfo<CTYPE_COMPUTE>{
+          internal::get_load_to_compute_fn<CTYPE_COMPUTE, op_name>(
+              *inputs.first, inputs.second),
+          reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
+          inputs.first->element_size(),
+      })...};
+
+  const auto store_compute_to_out =
+      internal::get_store_compute_to_tensor_fn<CTYPE_COMPUTE, op_name>(
+          out, out_dtypes);
+  if constexpr (std::is_convertible_v<Op, CTYPE_COMPUTE (*)(CTYPE_COMPUTE)>) {
+    apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op<
+        CTYPE_COMPUTE,
+        support_noncontiguous_tensors,
+        kNumInputs>(
+        inputs_info,
+        store_compute_to_out,
+        static_cast<CTYPE_COMPUTE (*)(CTYPE_COMPUTE)>(compute_fun),
+        ctx,
+        out,
+        inputs...);
+  } else if constexpr (std::is_convertible_v<
+                           Op,
+                           CTYPE_COMPUTE (*)(CTYPE_COMPUTE, CTYPE_COMPUTE)>) {
+    apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op<
+        CTYPE_COMPUTE,
+        support_noncontiguous_tensors,
+        kNumInputs>(
+        inputs_info,
+        store_compute_to_out,
+        static_cast<CTYPE_COMPUTE (*)(CTYPE_COMPUTE, CTYPE_COMPUTE)>(
+            compute_fun),
+        ctx,
+        out,
+        inputs...);
+  } else if constexpr (std::is_convertible_v<
+                           Op,
+                           CTYPE_COMPUTE (*)(
+                               CTYPE_COMPUTE, CTYPE_COMPUTE, CTYPE_COMPUTE)>) {
+    apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op<
+        CTYPE_COMPUTE,
+        support_noncontiguous_tensors,
+        kNumInputs>(
+        inputs_info,
+        store_compute_to_out,
+        static_cast<CTYPE_COMPUTE (*)(
+            CTYPE_COMPUTE, CTYPE_COMPUTE, CTYPE_COMPUTE)>(compute_fun),
+        ctx,
+        out,
+        inputs...);
+  } else {
+    apply_elementwise_fn_generic_impl_with_load_store_functions<
+        CTYPE_COMPUTE,
+        support_noncontiguous_tensors,
+        kNumInputs>(
+        inputs_info, store_compute_to_out, compute_fun, ctx, out, inputs...);
+  }
+}
+
 template <
     typename CTYPE_COMPUTE,
     const char* op_name,
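
For orientation, a minimal standalone sketch of the dispatch pattern the new apply_elementwise_fn_generic_impl relies on. This is not ExecuTorch code: the names outlined_apply and apply_unary are hypothetical, and a GCC/Clang-style noinline attribute stands in for ET_NOINLINE. The point is that a captureless op converts to a plain function pointer, so the element loop can be outlined into one noinline function per compute type (and arity) instead of being re-instantiated for every op's functor type.

#include <cstddef>
#include <type_traits>
#include <vector>

// Stand-in for ET_NOINLINE; assumes a GCC/Clang-compatible compiler.
#define MY_NOINLINE __attribute__((noinline))

// Outlined loop: instantiated once per CTYPE and shared by every unary op
// whose functor converts to a plain function pointer.
template <typename CTYPE>
MY_NOINLINE void outlined_apply(
    CTYPE (*op)(CTYPE), const CTYPE* in, CTYPE* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = op(in[i]);
  }
}

// Generic entry point: captureless ops take the shared outlined path,
// everything else (e.g. capturing lambdas) keeps the per-functor loop.
template <typename CTYPE, typename Op>
void apply_unary(const Op& op, const CTYPE* in, CTYPE* out, std::size_t n) {
  if constexpr (std::is_convertible_v<Op, CTYPE (*)(CTYPE)>) {
    outlined_apply<CTYPE>(static_cast<CTYPE (*)(CTYPE)>(op), in, out, n);
  } else {
    for (std::size_t i = 0; i < n; ++i) {
      out[i] = op(in[i]);
    }
  }
}

int main() {
  std::vector<float> in = {1.f, 2.f, 3.f};
  std::vector<float> out(in.size());
  // Captureless lambda: converts to float (*)(float), so it is routed
  // through the single outlined_apply<float> instantiation.
  apply_unary<float>(
      [](float x) { return x * 2.f; }, in.data(), out.data(), in.size());
  // Capturing lambda: no function-pointer conversion, falls back to the
  // inlined per-functor loop.
  const float bias = 1.f;
  apply_unary<float>(
      [bias](float x) { return x + bias; }, in.data(), out.data(), in.size());
  return 0;
}

The fallback branch matters because a capturing lambda does not convert to a function pointer; std::is_convertible_v is then false and the original templated path is kept, which mirrors the final else branch in the diff above.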
