@@ -217,35 +217,27 @@ inline bool validate_elementwise_fn_inputs(
217217 return true ;
218218}
219219
// Per-input bookkeeping captured once before the elementwise loop runs: how
// to load one stored element into the compute type, where the tensor's raw
// bytes live, and how wide one stored element is.
template <typename CTYPE_COMPUTE>
struct InputInfo {
  // Function that converts one stored element to CTYPE_COMPUTE.
  load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
  // Raw byte pointer to the input tensor's constant data.
  const char* data_ptr;
  // Size in bytes of one stored element of this input.
  ssize_t element_size;
};
226+
220227template <
221228 typename CTYPE_COMPUTE,
222- const char * op_name,
223229 bool support_noncontiguous_tensors,
230+ size_t kNumInputs ,
224231 typename Op,
225232 typename ... Args>
226- inline void apply_elementwise_fn_generic_impl (
233+ inline void apply_elementwise_fn_generic_impl_with_load_store_functions (
234+ const std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs >& inputs_info,
235+ store_compute_to_tensor_fn<CTYPE_COMPUTE> store_compute_to_out,
227236 const Op& compute_fun,
228237 KernelRuntimeContext& ctx,
229238 const Tensor& out,
230- SupportedTensorDtypes out_dtypes,
231239 Args... inputs) {
232- constexpr auto kNumInputs = sizeof ...(inputs);
233-
234- struct InputInfo {
235- load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
236- const char * data_ptr;
237- ssize_t element_size;
238- };
239- std::array<InputInfo, kNumInputs > inputs_info = {(InputInfo{
240- internal::get_load_to_compute_fn<CTYPE_COMPUTE, op_name>(
241- *inputs.first , inputs.second ),
242- reinterpret_cast <const char *>(inputs.first ->const_data_ptr ()),
243- inputs.first ->element_size (),
244- })...};
245-
246- const auto store_compute_to_out =
247- internal::get_store_compute_to_tensor_fn<CTYPE_COMPUTE, op_name>(
248- out, out_dtypes);
240+ static_assert (kNumInputs == sizeof ...(inputs));
249241 char * const data_out = reinterpret_cast <char *>(out.mutable_data_ptr ());
250242 const auto out_element_size = out.element_size ();
251243
@@ -275,6 +267,140 @@ inline void apply_elementwise_fn_generic_impl(
275267 });
276268}
277269
// Outlined (ET_NOINLINE) entry point for a stateless *unary* op taken as a
// plain function pointer. Routing function-pointer ops through this
// non-inlined wrapper lets the heavily-templated elementwise loop be shared
// across all ops with this signature instead of being re-instantiated per
// lambda type. NOTE(review): presumably a binary-size optimization — confirm
// against the build's size goals.
template <
    typename CTYPE_COMPUTE,
    bool support_noncontiguous_tensors,
    size_t kNumInputs,
    typename... Args>
ET_NOINLINE void
apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op(
    const std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs>& inputs_info,
    store_compute_to_tensor_fn<CTYPE_COMPUTE> store_compute_to_out,
    CTYPE_COMPUTE (*compute_fun)(CTYPE_COMPUTE),
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
  // Pure forwarder: identical arguments, with the op now a raw function
  // pointer and this frame kept out-of-line.
  apply_elementwise_fn_generic_impl_with_load_store_functions<
      CTYPE_COMPUTE,
      support_noncontiguous_tensors,
      kNumInputs>(
      inputs_info, store_compute_to_out, compute_fun, ctx, out, inputs...);
}
289+
// Outlined (ET_NOINLINE) entry point for a stateless *binary* op taken as a
// plain function pointer; see the unary overload for the rationale.
template <
    typename CTYPE_COMPUTE,
    bool support_noncontiguous_tensors,
    size_t kNumInputs,
    typename... Args>
ET_NOINLINE void
apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op(
    const std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs>& inputs_info,
    store_compute_to_tensor_fn<CTYPE_COMPUTE> store_compute_to_out,
    CTYPE_COMPUTE (*compute_fun)(CTYPE_COMPUTE, CTYPE_COMPUTE),
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
  // Pure forwarder to the generic implementation.
  apply_elementwise_fn_generic_impl_with_load_store_functions<
      CTYPE_COMPUTE,
      support_noncontiguous_tensors,
      kNumInputs>(
      inputs_info, store_compute_to_out, compute_fun, ctx, out, inputs...);
}
309+
// Outlined (ET_NOINLINE) entry point for a stateless *ternary* op taken as a
// plain function pointer; see the unary overload for the rationale.
template <
    typename CTYPE_COMPUTE,
    bool support_noncontiguous_tensors,
    size_t kNumInputs,
    typename... Args>
ET_NOINLINE void
apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op(
    const std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs>& inputs_info,
    store_compute_to_tensor_fn<CTYPE_COMPUTE> store_compute_to_out,
    CTYPE_COMPUTE (*compute_fun)(CTYPE_COMPUTE, CTYPE_COMPUTE, CTYPE_COMPUTE),
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
  // Pure forwarder to the generic implementation.
  apply_elementwise_fn_generic_impl_with_load_store_functions<
      CTYPE_COMPUTE,
      support_noncontiguous_tensors,
      kNumInputs>(
      inputs_info, store_compute_to_out, compute_fun, ctx, out, inputs...);
}
329+
330+ template <
331+ typename CTYPE_COMPUTE,
332+ const char * op_name,
333+ bool support_noncontiguous_tensors,
334+ typename Op,
335+ typename ... Args>
336+ inline void apply_elementwise_fn_generic_impl (
337+ const Op& compute_fun,
338+ KernelRuntimeContext& ctx,
339+ const Tensor& out,
340+ SupportedTensorDtypes out_dtypes,
341+ Args... inputs) {
342+ constexpr auto kNumInputs = sizeof ...(inputs);
343+
344+ std::array<InputInfo<CTYPE_COMPUTE>, kNumInputs > inputs_info = {
345+ (InputInfo<CTYPE_COMPUTE>{
346+ internal::get_load_to_compute_fn<CTYPE_COMPUTE, op_name>(
347+ *inputs.first , inputs.second ),
348+ reinterpret_cast <const char *>(inputs.first ->const_data_ptr ()),
349+ inputs.first ->element_size (),
350+ })...};
351+
352+ const auto store_compute_to_out =
353+ internal::get_store_compute_to_tensor_fn<CTYPE_COMPUTE, op_name>(
354+ out, out_dtypes);
355+ if constexpr (std::is_convertible_v<Op, CTYPE_COMPUTE (*)(CTYPE_COMPUTE)>) {
356+ apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op<
357+ CTYPE_COMPUTE,
358+ support_noncontiguous_tensors,
359+ kNumInputs >(
360+ inputs_info,
361+ store_compute_to_out,
362+ static_cast <CTYPE_COMPUTE (*)(CTYPE_COMPUTE)>(compute_fun),
363+ ctx,
364+ out,
365+ inputs...);
366+ } else if constexpr (std::is_convertible_v<
367+ Op,
368+ CTYPE_COMPUTE (*)(CTYPE_COMPUTE, CTYPE_COMPUTE)>) {
369+ apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op<
370+ CTYPE_COMPUTE,
371+ support_noncontiguous_tensors,
372+ kNumInputs >(
373+ inputs_info,
374+ store_compute_to_out,
375+ static_cast <CTYPE_COMPUTE (*)(CTYPE_COMPUTE, CTYPE_COMPUTE)>(
376+ compute_fun),
377+ ctx,
378+ out,
379+ inputs...);
380+ } else if constexpr (std::is_convertible_v<
381+ Op,
382+ CTYPE_COMPUTE (*)(
383+ CTYPE_COMPUTE, CTYPE_COMPUTE, CTYPE_COMPUTE)>) {
384+ apply_elementwise_fn_generic_impl_with_load_store_functions_and_outlined_op<
385+ CTYPE_COMPUTE,
386+ support_noncontiguous_tensors,
387+ kNumInputs >(
388+ inputs_info,
389+ store_compute_to_out,
390+ static_cast <CTYPE_COMPUTE (*)(
391+ CTYPE_COMPUTE, CTYPE_COMPUTE, CTYPE_COMPUTE)>(compute_fun),
392+ ctx,
393+ out,
394+ inputs...);
395+ } else {
396+ apply_elementwise_fn_generic_impl_with_load_store_functions<
397+ CTYPE_COMPUTE,
398+ support_noncontiguous_tensors,
399+ kNumInputs >(
400+ inputs_info, store_compute_to_out, compute_fun, ctx, out, inputs...);
401+ }
402+ }
403+
278404template <
279405 typename CTYPE_COMPUTE,
280406 const char * op_name,