41
41
#include < executorch/extension/kernel_util/meta_programming.h>
42
42
#include < executorch/extension/kernel_util/type_list.h>
43
43
#include < executorch/runtime/core/evalue.h>
44
+ #include < executorch/runtime/core/event_tracer_hooks.h>
44
45
#include < executorch/runtime/core/exec_aten/exec_aten.h>
46
+ #include < executorch/runtime/kernel/kernel_runtime_context.h>
45
47
#include < executorch/runtime/kernel/operator_registry.h>
46
48
#include < cstdlib>
47
49
#include < memory>
@@ -61,6 +63,49 @@ namespace extension {
61
63
// internal namespace to avoid conflicts with other extensions.
62
64
namespace kernel_util_internal {
63
65
66
+ // Template trait to check if a type is a non-const tensor
67
+ template <class T >
68
+ struct is_nonconst_tensor : std::false_type {};
69
+
70
+ template <>
71
+ struct is_nonconst_tensor <executorch::aten::Tensor&> : std::true_type {};
72
+
73
+ // Template trait to check if a type is a non-const tensor
74
+ // Count non-const tensors in a typelist
75
+ template <class TypeList >
76
+ struct count_nonconst_tensors ;
77
+
78
+ template <>
79
+ struct count_nonconst_tensors <typelist<>> {
80
+ static constexpr size_t value = 0 ;
81
+ };
82
+
83
+ template <class T >
84
+ struct count_nonconst_tensors <typelist<T>> {
85
+ static constexpr size_t value = 0 ;
86
+ };
87
+
88
+ template <>
89
+ struct count_nonconst_tensors <typelist<executorch::aten::Tensor&>> {
90
+ static constexpr size_t value = 1 ;
91
+ };
92
+
93
+ template <class Head , class ... Tail>
94
+ struct count_nonconst_tensors <typelist<Head, Tail...>> {
95
+ private:
96
+ static constexpr size_t tail_tensor_count =
97
+ count_nonconst_tensors<typelist<Tail...>>::value;
98
+ static constexpr size_t tail_args_count = sizeof ...(Tail);
99
+ static constexpr bool is_head_a_tensor = is_nonconst_tensor<Head>::value;
100
+ static constexpr bool all_tail_args_are_tensor =
101
+ tail_tensor_count == tail_args_count;
102
+
103
+ public:
104
+ static constexpr size_t value = (is_head_a_tensor && all_tail_args_are_tensor)
105
+ ? tail_tensor_count + 1
106
+ : tail_tensor_count;
107
+ };
108
+
64
109
template <class T >
65
110
struct decay_if_not_tensor final {
66
111
using type = std::decay_t <T>;
@@ -110,16 +155,29 @@ struct evalue_to_arg<executorch::aten::ArrayRef<std::optional<T>>> final {
110
155
}
111
156
};
112
157
113
- template <class Functor , size_t ... evalue_arg_indices, typename ... ArgTypes>
158
+ template <
159
+ class Functor ,
160
+ size_t nonconst_tensors_to_log,
161
+ size_t ... evalue_arg_indices,
162
+ typename ... ArgTypes>
114
163
void call_functor_with_args_from_stack (
115
164
executorch::runtime::KernelRuntimeContext& ctx,
116
165
executorch::runtime::Span<executorch::runtime::EValue*> stack,
117
166
std::index_sequence<evalue_arg_indices...>,
118
167
typelist<ArgTypes...>*) {
168
+ executorch::runtime::internal::EventTracerProfileOpScope
169
+ event_tracer_op_scope (ctx.internal_event_tracer (), Functor::func_name_);
170
+ EXECUTORCH_SCOPE_PROF (Functor::func_name_);
119
171
(*Functor::func_ptr ())(
120
172
ctx,
121
173
evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call (
122
174
*stack[evalue_arg_indices])...);
175
+ constexpr size_t num_inputs =
176
+ std::index_sequence<evalue_arg_indices...>::size ();
177
+ for (size_t i = num_inputs - nonconst_tensors_to_log; i < num_inputs; ++i) {
178
+ executorch::runtime::internal::event_tracer_log_evalue (
179
+ ctx.internal_event_tracer (), *stack[i]);
180
+ }
123
181
}
124
182
125
183
} // namespace kernel_util_internal
@@ -154,11 +212,16 @@ struct WrapUnboxedIntoFunctor {
154
212
executorch::runtime::Span<executorch::runtime::EValue*> stack) {
155
213
constexpr size_t num_inputs =
156
214
kernel_util_internal::size<ContextRemovedArgsType>::value;
157
- return kernel_util_internal::call_functor_with_args_from_stack<FuncType>(
158
- ctx,
159
- stack,
160
- std::make_index_sequence<num_inputs>(),
161
- static_cast <ContextRemovedArgsType*>(nullptr ));
215
+ constexpr size_t num_nonconst_tensors =
216
+ kernel_util_internal::count_nonconst_tensors<
217
+ ContextRemovedArgsType>::value;
218
+ static_assert (num_nonconst_tensors == 1 , " Invalid number of inputs" );
219
+ return kernel_util_internal::
220
+ call_functor_with_args_from_stack<FuncType, num_nonconst_tensors>(
221
+ ctx,
222
+ stack,
223
+ std::make_index_sequence<num_inputs>(),
224
+ static_cast <ContextRemovedArgsType*>(nullptr ));
162
225
}
163
226
};
164
227
@@ -181,11 +244,14 @@ static executorch::runtime::Kernel make_boxed_kernel(
181
244
#define EXECUTORCH_LIBRARY (ns, op_name, func ) \
182
245
_EXECUTORCH_LIBRARY_IMPL (ns, op_name, func, ET_UID)
183
246
184
- #define _EXECUTORCH_LIBRARY_IMPL (ns, op_name, func, uid ) \
185
- static auto ET_CONCATENATE (res_##ns##_, uid) = \
186
- ::executorch::runtime::register_kernel( \
187
- ::executorch::extension::make_boxed_kernel ( \
188
- #ns " ::" op_name, EXECUTORCH_FN(func)))
247
+ #define _EXECUTORCH_LIBRARY_IMPL (ns, op_name, func, uid ) \
248
+ static constexpr const char ET_CONCATENATE (name_of_op_, uid)[] = \
249
+ #ns " ::" op_name; \
250
+ static auto ET_CONCATENATE (res_##ns##_, uid) = \
251
+ ::executorch::runtime::register_kernel( \
252
+ ::executorch::extension::make_boxed_kernel ( \
253
+ #ns " ::" op_name, \
254
+ EXECUTORCH_FN (func, ET_CONCATENATE(name_of_op_, uid))))
189
255
190
256
namespace torch {
191
257
namespace executor {
0 commit comments