Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 78c2a6b

Browse files
authored
Apply functorch-side fix for PythonTLSSnapshot change (#633)
We add PythonTLSSnapshot to the local exclude set while functorch is active and remove it when we leave functorch. This is a quick fix; there's something nicer we can do (which is to refactor DynamicLayer so that it does not have so many RAII guards that modify the thread-local dispatcher state).
1 parent 3b9e79a commit 78c2a6b

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

functorch/csrc/DynamicLayer.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
2626
DispatchKey::Functionalize,
2727
// DispatchKey::Batched,
2828
kBatchedKey,
29+
DispatchKey::PythonTLSSnapshot,
2930
DispatchKey::ADInplaceOrView
3031
}) | autograd_dispatch_keyset;
3132

@@ -495,6 +496,9 @@ void dynamicLayerFrontFallback(const c10::OperatorHandle& op, torch::jit::Stack*
495496
if (dynamicLayerStack.size() == 0) {
496497
sanityCheckStack(op, stack);
497498
c10::impl::ExcludeDispatchKeyGuard guard(all_dynlayer_keyset);
499+
auto local_keyset = c10::impl::tls_local_dispatch_key_set();
500+
local_keyset.excluded_ = local_keyset.excluded_.remove(DispatchKey::PythonTLSSnapshot);
501+
c10::impl::ForceDispatchKeyGuard guard2(local_keyset);
498502
op.callBoxed(stack);
499503
return;
500504
}

test/test_eager_transforms.py

Lines changed: 30 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
from functorch import (
2424
grad, vjp, vmap, jacrev, jacfwd, grad_and_value, hessian,
2525
jvp, make_functional, make_functional_with_buffers,
26-
combine_state_for_ensemble,
26+
combine_state_for_ensemble, make_fx
2727
)
2828
from functorch._src.make_functional import (
2929
functional_init, functional_init_with_buffers,
@@ -2112,6 +2112,35 @@ def test_vjp_vjp(self, device):
21122112
y = vjp_fn(x)[0]
21132113
# Honestly IDK what the result here is... but at least it runs
21142114

2115+
def test_make_fx_vmap(self, device):
2116+
def f(x):
2117+
return torch.sin(x)
2118+
inp = torch.randn(5, 3)
2119+
f = vmap(f)
2120+
fx_f = make_fx(f)(inp)
2121+
new_inp = torch.randn(5, 3)
2122+
self.assertEqual(fx_f(new_inp), f(new_inp))
2123+
2124+
def test_make_fx_jacrev(self, device):
2125+
def f(x):
2126+
return x.sin().sum()
2127+
inp = torch.randn(3)
2128+
f = jacrev(jacrev(f))
2129+
fx_f = make_fx(f)(inp)
2130+
new_inp = torch.randn(3)
2131+
self.assertEqual(fx_f(new_inp), f(new_inp))
2132+
2133+
def test_make_fx_vjp(self, device):
2134+
def f(x):
2135+
return torch.sin(x).sum()
2136+
2137+
primals = torch.randn(3)
2138+
_, vjp_fn = vjp(f, primals)
2139+
cotangent = torch.randn(())
2140+
fx_f = make_fx(vjp_fn)(cotangent, True, True)
2141+
new_cotangent = torch.randn(())
2142+
self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
2143+
21152144

21162145
class TestMakeFunctional(TestCase):
21172146
def test_parameter_tying(self):

0 commit comments

Comments (0)