4141#include < executorch/extension/kernel_util/meta_programming.h>
4242#include < executorch/extension/kernel_util/type_list.h>
4343#include < executorch/runtime/core/evalue.h>
44+ #include < executorch/runtime/core/event_tracer_hooks.h>
4445#include < executorch/runtime/core/exec_aten/exec_aten.h>
46+ #include < executorch/runtime/kernel/kernel_runtime_context.h>
4547#include < executorch/runtime/kernel/operator_registry.h>
4648#include < cstdlib>
4749#include < memory>
@@ -61,6 +63,49 @@ namespace extension {
6163// internal namespace to avoid conflicts with other extensions.
6264namespace kernel_util_internal {
6365
66+ // Template trait to check if a type is a non-const tensor
67+ template <class T >
68+ struct is_nonconst_tensor : std::false_type {};
69+
70+ template <>
71+ struct is_nonconst_tensor <executorch::aten::Tensor&> : std::true_type {};
72+
73+ // Template trait to check if a type is a non-const tensor
74+ // Count non-const tensors in a typelist
75+ template <class TypeList >
76+ struct count_nonconst_tensors ;
77+
78+ template <>
79+ struct count_nonconst_tensors <typelist<>> {
80+ static constexpr size_t value = 0 ;
81+ };
82+
83+ template <class T >
84+ struct count_nonconst_tensors <typelist<T>> {
85+ static constexpr size_t value = 0 ;
86+ };
87+
88+ template <>
89+ struct count_nonconst_tensors <typelist<executorch::aten::Tensor&>> {
90+ static constexpr size_t value = 1 ;
91+ };
92+
93+ template <class Head , class ... Tail>
94+ struct count_nonconst_tensors <typelist<Head, Tail...>> {
95+ private:
96+ static constexpr size_t tail_tensor_count =
97+ count_nonconst_tensors<typelist<Tail...>>::value;
98+ static constexpr size_t tail_args_count = sizeof ...(Tail);
99+ static constexpr bool is_head_a_tensor = is_nonconst_tensor<Head>::value;
100+ static constexpr bool all_tail_args_are_tensor =
101+ tail_tensor_count == tail_args_count;
102+
103+ public:
104+ static constexpr size_t value = (is_head_a_tensor && all_tail_args_are_tensor)
105+ ? tail_tensor_count + 1
106+ : tail_tensor_count;
107+ };
108+
64109template <class T >
65110struct decay_if_not_tensor final {
66111 using type = std::decay_t <T>;
@@ -110,16 +155,29 @@ struct evalue_to_arg<executorch::aten::ArrayRef<std::optional<T>>> final {
110155 }
111156};
112157
113- template <class Functor , size_t ... evalue_arg_indices, typename ... ArgTypes>
158+ template <
159+ class Functor ,
160+ size_t nonconst_tensors_to_log,
161+ size_t ... evalue_arg_indices,
162+ typename ... ArgTypes>
114163void call_functor_with_args_from_stack (
115164 executorch::runtime::KernelRuntimeContext& ctx,
116165 executorch::runtime::Span<executorch::runtime::EValue*> stack,
117166 std::index_sequence<evalue_arg_indices...>,
118167 typelist<ArgTypes...>*) {
168+ executorch::runtime::internal::EventTracerProfileOpScope
169+ event_tracer_op_scope (ctx.internal_event_tracer (), Functor::func_name_);
170+ EXECUTORCH_SCOPE_PROF (Functor::func_name_);
119171 (*Functor::func_ptr ())(
120172 ctx,
121173 evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call (
122174 *stack[evalue_arg_indices])...);
175+ constexpr size_t num_inputs =
176+ std::index_sequence<evalue_arg_indices...>::size ();
177+ for (size_t i = num_inputs - nonconst_tensors_to_log; i < num_inputs; ++i) {
178+ executorch::runtime::internal::event_tracer_log_evalue (
179+ ctx.internal_event_tracer (), *stack[i]);
180+ }
123181}
124182
125183} // namespace kernel_util_internal
@@ -154,11 +212,16 @@ struct WrapUnboxedIntoFunctor {
154212 executorch::runtime::Span<executorch::runtime::EValue*> stack) {
155213 constexpr size_t num_inputs =
156214 kernel_util_internal::size<ContextRemovedArgsType>::value;
157- return kernel_util_internal::call_functor_with_args_from_stack<FuncType>(
158- ctx,
159- stack,
160- std::make_index_sequence<num_inputs>(),
161- static_cast <ContextRemovedArgsType*>(nullptr ));
215+ constexpr size_t num_nonconst_tensors =
216+ kernel_util_internal::count_nonconst_tensors<
217+ ContextRemovedArgsType>::value;
218+ static_assert (num_nonconst_tensors == 1 , " Invalid number of inputs" );
219+ return kernel_util_internal::
220+ call_functor_with_args_from_stack<FuncType, num_nonconst_tensors>(
221+ ctx,
222+ stack,
223+ std::make_index_sequence<num_inputs>(),
224+ static_cast <ContextRemovedArgsType*>(nullptr ));
162225 }
163226};
164227
@@ -181,11 +244,14 @@ static executorch::runtime::Kernel make_boxed_kernel(
181244#define EXECUTORCH_LIBRARY (ns, op_name, func ) \
182245 _EXECUTORCH_LIBRARY_IMPL (ns, op_name, func, ET_UID)
183246
184- #define _EXECUTORCH_LIBRARY_IMPL (ns, op_name, func, uid ) \
185- static auto ET_CONCATENATE (res_##ns##_, uid) = \
186- ::executorch::runtime::register_kernel( \
187- ::executorch::extension::make_boxed_kernel ( \
188- #ns " ::" op_name, EXECUTORCH_FN(func)))
247+ #define _EXECUTORCH_LIBRARY_IMPL (ns, op_name, func, uid ) \
248+ static constexpr const char ET_CONCATENATE (name_of_op_, uid)[] = \
249+ #ns " ::" op_name; \
250+ static auto ET_CONCATENATE (res_##ns##_, uid) = \
251+ ::executorch::runtime::register_kernel( \
252+ ::executorch::extension::make_boxed_kernel ( \
253+ #ns " ::" op_name, \
254+ EXECUTORCH_FN (func, ET_CONCATENATE(name_of_op_, uid))))
189255
190256namespace torch {
191257namespace executor {
0 commit comments