Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 8eb6af6

Browse files
committed
Some more utilities
1 parent ae4b3bd commit 8eb6af6

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

functorch/csrc/DynamicLayer.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,12 @@ static DispatchKeySet keysForEnteringDynamicLayer(DispatchKey key) {
445445
}
446446
}
447447

448+
static void dump_local_tls() {
449+
auto tls = c10::impl::tls_local_dispatch_key_set();
450+
std::cout << "[Local Include] " << tls.included_ << std::endl;
451+
std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
452+
}
453+
448454
static DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(DispatchKey key) {
449455
DispatchKeySet exclude = all_dynlayer_keyset;
450456
exclude = exclude.remove(kDynamicLayerBackModeKey);
@@ -457,6 +463,7 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
457463
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
458464
if (c10::show_dispatch_trace_enabled()) {
459465
std::cout << dynamicLayerStack << std::endl;
466+
dump_local_tls();
460467
}
461468
#endif
462469
if (dynamicLayerStack.size() == 0) {
@@ -498,6 +505,12 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
498505
local_keyset.included_ = local_keyset.included_ | hacky_include;
499506
ForceLocalDispatchKeySet guard(local_keyset);
500507

508+
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
509+
if (c10::show_dispatch_trace_enabled()) {
510+
dump_local_tls();
511+
}
512+
#endif
513+
501514
// Re-dispatch
502515
op.callBoxed(stack);
503516
}
@@ -594,6 +607,12 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
594607
c10::impl::_force_tls_local_dispatch_key_set(keyset);
595608
setDynamicLayerFrontBackKeysIncluded(true);
596609

610+
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
611+
if (c10::show_dispatch_trace_enabled()) {
612+
dump_local_tls();
613+
}
614+
#endif
615+
597616
// Re-dispatch
598617
if (cur_key == DispatchKey::Autograd && *prev_grad_mode == false) {
599618
// See NOTE [grad and vjp interaction with no_grad]

functorch/csrc/init.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ static void dump_dls() {
233233
std::cout << getDynamicLayerStack() << std::endl;
234234
}
235235

236+
static void dump_local_tls() {
237+
auto tls = c10::impl::tls_local_dispatch_key_set();
238+
std::cout << "[Local Include] " << tls.included_ << std::endl;
239+
std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
240+
}
241+
236242
} // namespace functorch
237243
}
238244

@@ -268,6 +274,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
268274
m.def("tls_set_vmap_excluded", &at::functorch::tls_set_vmap_excluded);
269275
m.def("tls_set_is_included", &at::functorch::tls_set_is_included);
270276
m.def("dump_dls", &at::functorch::dump_dls);
277+
m.def("dump_local_tls", &at::functorch::dump_local_tls);
271278
at::functorch::initPointwiseOperatorCompileCacheBindings(m.ptr());
272279
at::functorch::initCompileCacheBindings(m.ptr());
273280
initDispatchBindings(m.ptr());

0 commit comments

Comments
 (0)