@@ -32,6 +32,20 @@ static void setDynamicLayerFrontBackKeysIncluded(bool included) {
32
32
c10::impl::tls_set_dispatch_key_included (kDynamicLayerBackModeKey , included);
33
33
}
34
34
35
+ struct ForceLocalDispatchKeySet {
36
+ public:
37
+ ForceLocalDispatchKeySet (c10::impl::LocalDispatchKeySet key_set) :
38
+ saved_keyset_ (c10::impl::tls_local_dispatch_key_set()) {
39
+ c10::impl::_force_tls_local_dispatch_key_set (key_set);
40
+ }
41
+ ~ForceLocalDispatchKeySet () {
42
+ c10::impl::_force_tls_local_dispatch_key_set (saved_keyset_);
43
+ }
44
+
45
+ private:
46
+ c10::impl::LocalDispatchKeySet saved_keyset_;
47
+ };
48
+
35
49
DynamicLayer::DynamicLayer (
36
50
DispatchKey key,
37
51
int64_t layerId,
@@ -468,7 +482,7 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
468
482
auto layer = dynamicLayerStack.back ();
469
483
470
484
DispatchKeySet exclude = keysToExcludeWhenEnteringDynamicLayer (layer.key ());
471
- DispatchKeySet include ;
485
+ DispatchKeySet hacky_include ;
472
486
// hack
473
487
if (layer.key () == kBatchedKey ) {
474
488
// Only enable dispatch on kBatchedKey if there are tensors batched
@@ -477,10 +491,12 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
477
491
if (allTensors (args, notBatchedAtCurrentLevel)) {
478
492
exclude = exclude.add (kBatchedKey );
479
493
}
480
- include = include .add (kVmapModeKey );
494
+ hacky_include = hacky_include .add (kVmapModeKey );
481
495
}
482
- c10::impl::ExcludeDispatchKeyGuard exclude_guard (exclude);
483
- c10::impl::IncludeDispatchKeyGuard include_guard (include);
496
+ auto local_keyset = c10::impl::tls_local_dispatch_key_set ();
497
+ local_keyset.excluded_ = local_keyset.excluded_ | exclude;
498
+ local_keyset.included_ = local_keyset.included_ | hacky_include;
499
+ ForceLocalDispatchKeySet guard (local_keyset);
484
500
485
501
// Re-dispatch
486
502
op.callBoxed (stack);
0 commit comments