Skip to content

Commit 855c083

Browse files
authored
Allow custom ops to log outputs and profiling events for etdump
Differential Revision: D81131610 Pull Request resolved: #14010
1 parent 1d4c6ba commit 855c083

File tree

2 files changed

+86
-17
lines changed

2 files changed

+86
-17
lines changed

extension/kernel_util/make_boxed_from_unboxed_functor.h

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
#include <executorch/extension/kernel_util/meta_programming.h>
4242
#include <executorch/extension/kernel_util/type_list.h>
4343
#include <executorch/runtime/core/evalue.h>
44+
#include <executorch/runtime/core/event_tracer_hooks.h>
4445
#include <executorch/runtime/core/exec_aten/exec_aten.h>
46+
#include <executorch/runtime/kernel/kernel_runtime_context.h>
4547
#include <executorch/runtime/kernel/operator_registry.h>
4648
#include <cstdlib>
4749
#include <memory>
@@ -61,6 +63,49 @@ namespace extension {
6163
// internal namespace to avoid conflicts with other extensions.
6264
namespace kernel_util_internal {
6365

66+
// Template trait to check if a type is a non-const tensor
67+
template <class T>
68+
struct is_nonconst_tensor : std::false_type {};
69+
70+
template <>
71+
struct is_nonconst_tensor<executorch::aten::Tensor&> : std::true_type {};
72+
73+
// Counts the non-const tensor arguments that form the contiguous tail of a
74+
// typelist: a `Tensor&` followed by any non-`Tensor&` argument is not counted.
75+
template <class TypeList>
76+
struct count_nonconst_tensors;
77+
78+
template <>
79+
struct count_nonconst_tensors<typelist<>> {
80+
static constexpr size_t value = 0;
81+
};
82+
83+
template <class T>
84+
struct count_nonconst_tensors<typelist<T>> {
85+
static constexpr size_t value = 0;
86+
};
87+
88+
template <>
89+
struct count_nonconst_tensors<typelist<executorch::aten::Tensor&>> {
90+
static constexpr size_t value = 1;
91+
};
92+
93+
template <class Head, class... Tail>
94+
struct count_nonconst_tensors<typelist<Head, Tail...>> {
95+
private:
96+
static constexpr size_t tail_tensor_count =
97+
count_nonconst_tensors<typelist<Tail...>>::value;
98+
static constexpr size_t tail_args_count = sizeof...(Tail);
99+
static constexpr bool is_head_a_tensor = is_nonconst_tensor<Head>::value;
100+
static constexpr bool all_tail_args_are_tensor =
101+
tail_tensor_count == tail_args_count;
102+
103+
public:
104+
static constexpr size_t value = (is_head_a_tensor && all_tail_args_are_tensor)
105+
? tail_tensor_count + 1
106+
: tail_tensor_count;
107+
};
108+
64109
template <class T>
65110
struct decay_if_not_tensor final {
66111
using type = std::decay_t<T>;
@@ -110,16 +155,29 @@ struct evalue_to_arg<executorch::aten::ArrayRef<std::optional<T>>> final {
110155
}
111156
};
112157

113-
template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes>
158+
template <
159+
class Functor,
160+
size_t nonconst_tensors_to_log,
161+
size_t... evalue_arg_indices,
162+
typename... ArgTypes>
114163
void call_functor_with_args_from_stack(
115164
executorch::runtime::KernelRuntimeContext& ctx,
116165
executorch::runtime::Span<executorch::runtime::EValue*> stack,
117166
std::index_sequence<evalue_arg_indices...>,
118167
typelist<ArgTypes...>*) {
168+
executorch::runtime::internal::EventTracerProfileOpScope
169+
event_tracer_op_scope(ctx.internal_event_tracer(), Functor::func_name_);
170+
EXECUTORCH_SCOPE_PROF(Functor::func_name_);
119171
(*Functor::func_ptr())(
120172
ctx,
121173
evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
122174
*stack[evalue_arg_indices])...);
175+
constexpr size_t num_inputs =
176+
std::index_sequence<evalue_arg_indices...>::size();
177+
for (size_t i = num_inputs - nonconst_tensors_to_log; i < num_inputs; ++i) {
178+
executorch::runtime::internal::event_tracer_log_evalue(
179+
ctx.internal_event_tracer(), *stack[i]);
180+
}
123181
}
124182

125183
} // namespace kernel_util_internal
@@ -154,11 +212,16 @@ struct WrapUnboxedIntoFunctor {
154212
executorch::runtime::Span<executorch::runtime::EValue*> stack) {
155213
constexpr size_t num_inputs =
156214
kernel_util_internal::size<ContextRemovedArgsType>::value;
157-
return kernel_util_internal::call_functor_with_args_from_stack<FuncType>(
158-
ctx,
159-
stack,
160-
std::make_index_sequence<num_inputs>(),
161-
static_cast<ContextRemovedArgsType*>(nullptr));
215+
constexpr size_t num_nonconst_tensors =
216+
kernel_util_internal::count_nonconst_tensors<
217+
ContextRemovedArgsType>::value;
218+
static_assert(num_nonconst_tensors == 1, "Invalid number of inputs");
219+
return kernel_util_internal::
220+
call_functor_with_args_from_stack<FuncType, num_nonconst_tensors>(
221+
ctx,
222+
stack,
223+
std::make_index_sequence<num_inputs>(),
224+
static_cast<ContextRemovedArgsType*>(nullptr));
162225
}
163226
};
164227

@@ -181,11 +244,14 @@ static executorch::runtime::Kernel make_boxed_kernel(
181244
#define EXECUTORCH_LIBRARY(ns, op_name, func) \
182245
_EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, ET_UID)
183246

184-
#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid) \
185-
static auto ET_CONCATENATE(res_##ns##_, uid) = \
186-
::executorch::runtime::register_kernel( \
187-
::executorch::extension::make_boxed_kernel( \
188-
#ns "::" op_name, EXECUTORCH_FN(func)))
247+
#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid) \
248+
static constexpr const char ET_CONCATENATE(name_of_op_, uid)[] = \
249+
#ns "::" op_name; \
250+
static auto ET_CONCATENATE(res_##ns##_, uid) = \
251+
::executorch::runtime::register_kernel( \
252+
::executorch::extension::make_boxed_kernel( \
253+
#ns "::" op_name, \
254+
EXECUTORCH_FN(func, ET_CONCATENATE(name_of_op_, uid))))
189255

190256
namespace torch {
191257
namespace executor {

extension/kernel_util/meta_programming.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@ template <class T>
3232
using is_function_type_t = typename is_function_type<T>::type;
3333

3434
// A compile-time wrapper around a function pointer and the name it is exposed under (`func_name_`)
35-
template <class FuncType_, FuncType_* func_ptr_>
35+
template <class FuncType_, FuncType_* func_ptr_, const char* func_name>
3636
struct CompileTimeFunctionPointer final {
3737
static_assert(
3838
is_function_type<FuncType_>::value,
3939
"EXECUTORCH_FN can only wrap function types.");
4040
using FuncType = FuncType_;
41+
static constexpr const char* func_name_ = func_name;
4142

4243
static constexpr FuncType* func_ptr() {
4344
return func_ptr_;
@@ -47,15 +48,17 @@ struct CompileTimeFunctionPointer final {
4748
// Check if a given type is a compile-time function pointer
4849
template <class T>
4950
struct is_compile_time_function_pointer : std::false_type {};
50-
template <class FuncType, FuncType* func_ptr>
51+
template <class FuncType, FuncType* func_ptr, const char* func_name>
5152
struct is_compile_time_function_pointer<
52-
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
53+
CompileTimeFunctionPointer<FuncType, func_ptr, func_name>>
54+
: std::true_type {};
5355

54-
#define EXECUTORCH_FN_TYPE(func) \
56+
#define EXECUTORCH_FN_TYPE(func, name) \
5557
::executorch::extension::kernel_util_internal::CompileTimeFunctionPointer< \
5658
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
57-
func>
58-
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
59+
func, \
60+
name>
61+
#define EXECUTORCH_FN(func, name) EXECUTORCH_FN_TYPE(func, name)()
5962

6063
/**
6164
* strip_class: helper to remove the class type from pointers to `operator()`.

0 commit comments

Comments
 (0)