Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 7b3c756

Browse files
committed
Introduce ForceLocalDispatchKeySet
IncludeDispatchKeyGuard and ExcludeDispatchKeyGuard were getting too confusing
1 parent b5d80d7 commit 7b3c756

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

functorch/csrc/DynamicLayer.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ static void setDynamicLayerFrontBackKeysIncluded(bool included) {
3232
c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, included);
3333
}
3434

35+
struct ForceLocalDispatchKeySet {
36+
public:
37+
ForceLocalDispatchKeySet(c10::impl::LocalDispatchKeySet key_set) :
38+
saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {
39+
c10::impl::_force_tls_local_dispatch_key_set(key_set);
40+
}
41+
~ForceLocalDispatchKeySet() {
42+
c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
43+
}
44+
45+
private:
46+
c10::impl::LocalDispatchKeySet saved_keyset_;
47+
};
48+
3549
DynamicLayer::DynamicLayer(
3650
DispatchKey key,
3751
int64_t layerId,
@@ -468,7 +482,7 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
468482
auto layer = dynamicLayerStack.back();
469483

470484
DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer(layer.key());
471-
DispatchKeySet include;
485+
DispatchKeySet hacky_include;
472486
// hack
473487
if (layer.key() == kBatchedKey) {
474488
// Only enable dispatch on kBatchedKey if there are tensors batched
@@ -477,10 +491,12 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
477491
if (allTensors(args, notBatchedAtCurrentLevel)) {
478492
exclude = exclude.add(kBatchedKey);
479493
}
480-
include = include.add(kVmapModeKey);
494+
hacky_include = hacky_include.add(kVmapModeKey);
481495
}
482-
c10::impl::ExcludeDispatchKeyGuard exclude_guard(exclude);
483-
c10::impl::IncludeDispatchKeyGuard include_guard(include);
496+
auto local_keyset = c10::impl::tls_local_dispatch_key_set();
497+
local_keyset.excluded_ = local_keyset.excluded_ | exclude;
498+
local_keyset.included_ = local_keyset.included_ | hacky_include;
499+
ForceLocalDispatchKeySet guard(local_keyset);
484500

485501
// Re-dispatch
486502
op.callBoxed(stack);

0 commit comments

Comments
 (0)