
Commit 28d3ef7

Allow people to arbitrarily add dispatch keys between DynamicLayer{Front,Back} (#843)
Fixes #842

The Diagnosis
=============
As Brian pointed out: for jvp(sub, ...), the chain of dispatch should be:
```
DynamicLayerFrontMode -> at::sub autograd kernel -> DynamicLayerBackMode
```
Instead, what we're doing today is:
```
JVP dynamic layer -> at::sub autograd kernel -> at::sub zero_kernel
```
(The ZeroTensor kernel errors out because the inputs are BatchedTensorImpl objects.)

The Problem
===========
functorch's behavior on dispatch keys between DynamicLayerFrontMode and DynamicLayerBackMode should be:
- upon entering a dynamic layer (aka Interpreter), we zero out all dispatch keys* between FrontMode and BackMode;
- then, the dynamic layer (aka Interpreter) decides to re-enable some dispatch keys. For example, the JVPInterpreter re-enables the autograd keys;
- next, we do a dispatcher call, which ends up hitting one of the Autograd keys (in the JVPInterpreter case).

The bug is that functorch has a hardcoded list of dispatch keys that it zeros out. This list does not include ZeroTensor, because before pytorch/pytorch#77132 the ZeroTensor key was not between DynamicLayer{Front,Back}Mode.

*There is an exception for autocast and vmap mode, described in the next section.

The Solution
============
Change functorch to programmatically zero out the keys between DynamicLayerBackMode and DynamicLayerFrontMode, with the exception of Autocast and VmapMode. This means that if someone adds a dispatch key between DynamicLayerBackMode and DynamicLayerFrontMode in the future, we will (probably) handle it "correctly". The model for dispatch is:
- [functorch] -> [regular pytorch dispatcher]
- a key like ZeroTensor gets handled in the [regular pytorch dispatcher] section;
- functorch transforms get handled in the [functorch] section.

We do not change the autocast <-> functorch interaction in this PR (i.e. functorch does not zero it out) because I'm not sure what the correct thing to do there is. We do not change how kVmapMode works because it needs to stay active to ban random operations in transforms later down the line.

Test Plan
=========
Wait for tests.
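For context, here is a standalone sketch of the repro exercised by the regression test added in this commit (see the diff to test/test_eager_transforms.py below). The `from functorch import jvp, vmap` imports and the comments are assumptions added for illustration; the shapes and call pattern are taken directly from the test.

```python
# Standalone sketch of the new regression test, assuming functorch's
# top-level jvp/vmap exports. Shapes follow the test diff below.
import torch
from functorch import jvp, vmap

dummy = torch.ones(4, 1)
x = torch.randn(4, 2)
x_tangent = torch.randn(2)

def push_jvp(dummy, x):
    # jvp runs under two layers of vmap, so its inputs are batched tensors.
    # Before this change, dispatch could fall through to a ZeroTensor kernel
    # (e.g. at::sub's) between DynamicLayer{Front,Back}Mode and error out.
    return jvp(torch.cov, (x,), (x_tangent,))

# Should not error after this change.
vmap(vmap(push_jvp, (0, None)))(dummy, x)
```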
1 parent: 693bcee

File tree

- functorch/csrc/Interpreter.cpp
- test/test_eager_transforms.py

2 files changed (+36, -10 lines)

functorch/csrc/Interpreter.cpp

Lines changed: 24 additions & 10 deletions
```diff
@@ -7,16 +7,30 @@
 
 namespace at { namespace functorch {
 
-constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({
-  kDynamicLayerFrontModeKey,
-  kDynamicLayerBackModeKey,
-  kGradWrapperKey,
-  DispatchKey::Functionalize,
-  // DispatchKey::Batched,
-  kBatchedKey,
-  DispatchKey::PythonTLSSnapshot,
-  DispatchKey::ADInplaceOrView
-}) | autograd_dispatch_keyset;
+static DispatchKeySet get_all_dynlayer_keyset() {
+  // NB: FULL_AFTER does not include the dispatch key
+
+  // "all dispatch keys between DynamicLayer{Front, Back}Mode, inclusive"
+  auto result =
+    DispatchKeySet(DispatchKeySet::FULL_AFTER, kDynamicLayerFrontModeKey) -
+    DispatchKeySet(DispatchKeySet::FULL_AFTER, kDynamicLayerBackModeKey);
+  result = result | DispatchKeySet({kDynamicLayerFrontModeKey});
+
+  // Hack: don't handle the autocast dispatch keys. Their interaction with functorch
+  // is weird.
+  result = result - autocast_dispatch_keyset;
+
+  // Hack: don't handle kVmapModeKey. We need a better way of modeling this.
+  // In e.g. grad(vmap(f)), kVmapModeKey makes it so that all random operations,
+  // even after we are done handling the vmap layer, error out.
+  result = result.remove(kVmapModeKey);
+
+  return result;
+}
+
+// TODO: This should be constexpr, but there are some methods
+// of DispatchKeySet that haven't been marked constexpr yet.
+static DispatchKeySet all_dynlayer_keyset = get_all_dynlayer_keyset();
 
 static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) {
   if (key == TransformType::Vmap) {
```

test/test_eager_transforms.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -1999,6 +1999,18 @@ def backward(ctx, gx):
         gx, = torch.autograd.grad(y, x)
         self.assertEqual(gx, x.cos())
 
+    def test_zerotensor_vmapjvp_interaction(self, device):
+        dummy = torch.ones(4, 1)
+        x = torch.randn(4, 2)
+        x_tangent = torch.randn(2)
+
+        def push_jvp(dummy, x):
+            result = jvp(torch.cov, (x,), (x_tangent,))
+            return result
+
+        # Should not error
+        vmap(vmap(push_jvp, (0, None)))(dummy, x)
+
 
 class TestCustomFunction(TestCase):
     @unittest.skipIf(IS_WINDOWS, "Prototype of custom_vjp doesn't link on windows")
```
