|
7 | 7 |
|
8 | 8 | namespace at { namespace functorch {
|
9 | 9 |
|
10 |
| -constexpr DispatchKeySet all_dynlayer_keyset = DispatchKeySet({ |
11 |
| - kDynamicLayerFrontModeKey, |
12 |
| - kDynamicLayerBackModeKey, |
13 |
| - kGradWrapperKey, |
14 |
| - DispatchKey::Functionalize, |
15 |
| - // DispatchKey::Batched, |
16 |
| - kBatchedKey, |
17 |
| - DispatchKey::PythonTLSSnapshot, |
18 |
| - DispatchKey::ADInplaceOrView |
19 |
| -}) | autograd_dispatch_keyset; |
| 10 | +static DispatchKeySet get_all_dynlayer_keyset() { |
| 11 | + // NB: FULL_AFTER does not include the dispatch key |
| 12 | + |
| 13 | + // "all dispatch keys between DynamicLayer{Front, Back}Mode, inclusive" |
| 14 | + auto result = |
| 15 | + DispatchKeySet(DispatchKeySet::FULL_AFTER, kDynamicLayerFrontModeKey) - |
| 16 | + DispatchKeySet(DispatchKeySet::FULL_AFTER, kDynamicLayerBackModeKey); |
| 17 | + result = result | DispatchKeySet({kDynamicLayerFrontModeKey}); |
| 18 | + |
| 19 | + // Hack: don't handle the autocast dispatch keys. Their interaction with functorch |
| 20 | + // is weird. |
| 21 | + result = result - autocast_dispatch_keyset; |
| 22 | + |
| 23 | + // Hack: don't handle kVmapModeKey. We need a better way of modeling this. |
| 24 | + // In e.g. grad(vmap(f)), kVmapModeKey makes it so that all random operations, |
| 25 | + // even after we are done handling the vmap layer, error out. |
| 26 | + result = result.remove(kVmapModeKey); |
| 27 | + |
| 28 | + return result; |
| 29 | +} |
| 30 | + |
| 31 | +// TODO: This should be constexpr, but there are some methods |
| 32 | +// of DispatchKeySet that haven't been marked constexpr yet. |
| 33 | +static DispatchKeySet all_dynlayer_keyset = get_all_dynlayer_keyset(); |
20 | 34 |
|
21 | 35 | static DispatchKeySet keysForEnteringDynamicLayer(TransformType key) {
|
22 | 36 | if (key == TransformType::Vmap) {
|
|
0 commit comments