Commit 17f407c

no_grad fix (#179)
Fixes #13

Case 1: grad gets called inside torch.no_grad.
- grad should ignore torch.no_grad because it's "creating a new level of autograd above the current level"
- Another way to think about this is that grad(f) is a "function transform": its result should not be affected by context managers that are outside of the function f

Case 2: torch.no_grad gets called inside `grad`.
- grad should respect torch.no_grad

See NOTE [grad and vjp interaction with no_grad] for the implementation strategy. It unfortunately involves a mode.

Test Plan:
- Many tests
1 parent 4038cc4 commit 17f407c
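A minimal sketch of the two cases from the commit message, assuming the transform is importable as `from functorch import grad`; the numeric check is worked by hand here, not taken from the test plan:

```python
import torch
from functorch import grad

def f(x):
    # Case 2: a torch.no_grad inside the transformed function is respected,
    # so `c` is treated as a constant by grad.
    with torch.no_grad():
        c = x ** 2
    return x - c

x = torch.tensor(3.0)

# Case 1: grad gets called inside torch.no_grad. The outer no_grad is ignored
# because grad(f) creates a new level of autograd above the current level.
with torch.no_grad():
    result = grad(f)(x)

# d/dx (x - c), with c held constant, is 1.
assert torch.allclose(result, torch.tensor(1.0))
assert torch.allclose(grad(f)(x), torch.tensor(1.0))  # same answer with grad mode on
```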

5 files changed, +181 -72 lines changed

functorch/_src/eager_transforms.py

Lines changed: 91 additions & 41 deletions
@@ -88,27 +88,75 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, creat
                         for gi, inp in zip(grad_inputs, inputs))
     return grad_inputs
 
+# NOTE [grad and vjp interaction with no_grad]
+#
+# def f(x):
+#   with torch.no_grad():
+#     c = x ** 2
+#   return x - c
+#
+# The thing to consider is if enable_grad is on/off before grad gets called.
+#
+# Case 1: enable_grad is on.
+#   grad(f)(x)
+# In this case, `grad` should respect the inner torch.no_grad.
+#
+# Case 2: enable_grad is off
+#   with torch.no_grad():
+#     grad(f)(x)
+# In this case, `grad` should respect the inner torch.no_grad, but not the
+# outer one. This is because `grad` is a "function transform": its result
+# should not depend on the result of a context manager outside of `f`.
+#
+# This gives us the following desired behavior:
+# - (nested) grad transforms must obey torch.no_grad inside them
+# - (nested) grad transforms should not obey torch.no_grad outside them
+#
+# To achieve this behavior, upon entering grad/vjp:
+# - we save the current ("previous") is_grad_enabled (*)
+# - we unconditionally enable grad.
+#
+# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
+# off the stack:
+# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
+#   active, all subsequent grad transforms must obey it).
+# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
+#   then we temporarily restore the previous `is_grad_enabled`. This is
+#   because we're crossing the boundary from a `grad` outside the
+#   no_grad to a `grad` inside the no_grad.
+#
+# NB: vjp has some interesting behavior because the vjp's callable can be called
+# under a different grad_mode than the forward computation...
+#
+# TODO: forward-mode AD: does it also respect no_grad? What does that mean
+# for our jvp transform?
+
+
 # How do we increment and decrement the nesting? I don't think we can.
 def vjp(f, *primals):
     level = _grad_increment_nesting()
     try:
-        primals = _wrap_all_tensors(primals, level)
-        diff_primals = _create_differentiable(primals, level)
-        primals_out = f(*diff_primals)
-
-        results = _undo_create_differentiable(primals_out, level)
-        flat_diff_primals, primals_spec = tree_flatten(diff_primals)
-        flat_primals_out, primals_out_spec = tree_flatten(primals_out)
-
-        for primal_out in flat_primals_out:
-            assert isinstance(primal_out, torch.Tensor)
-            if primal_out.is_floating_point() or primal_out.is_complex():
-                continue
-            raise RuntimeError("vjp(f, ...): All outputs of f must be "
-                               "floating-point or complex Tensors, got Tensor "
-                               f"with dtype {primal_out.dtype}")
-
-        def wrapper(cotangents, retain_graph=True, create_graph=True):
+        # See NOTE [grad and vjp interaction with no_grad]
+        with torch.enable_grad():
+            primals = _wrap_all_tensors(primals, level)
+            diff_primals = _create_differentiable(primals, level)
+            primals_out = f(*diff_primals)
+
+            results = _undo_create_differentiable(primals_out, level)
+            flat_diff_primals, primals_spec = tree_flatten(diff_primals)
+            flat_primals_out, primals_out_spec = tree_flatten(primals_out)
+
+            for primal_out in flat_primals_out:
+                assert isinstance(primal_out, torch.Tensor)
+                if primal_out.is_floating_point() or primal_out.is_complex():
+                    continue
+                raise RuntimeError("vjp(f, ...): All outputs of f must be "
+                                   "floating-point or complex Tensors, got Tensor "
+                                   f"with dtype {primal_out.dtype}")
+
+        def wrapper(cotangents, retain_graph=True, create_graph=None):
+            if create_graph is None:
+                create_graph = torch.is_grad_enabled()
             flat_cotangents, cotangents_spec = tree_flatten(cotangents)
             if primals_out_spec != cotangents_spec:
                 raise RuntimeError(
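With the new `create_graph=None` default, the backward callable picks up the grad mode at call time instead of always building a higher-order graph. A hedged sketch of the intended usage (the `torch.sin` choice and cotangent values are illustrative); the second hunk below applies the same `torch.enable_grad()` treatment to the `grad_and_value` wrapper:

```python
import torch
from functorch import vjp

x = torch.randn(3)

# Even under an outer torch.no_grad, the forward pass inside vjp runs with
# grad enabled (see the `with torch.enable_grad():` block above).
with torch.no_grad():
    out, vjp_fn = vjp(torch.sin, x)
    # Called here, create_graph=None resolves to torch.is_grad_enabled(),
    # i.e. False, so no graph is recorded for the backward computation.
    (grad_x,) = vjp_fn(torch.ones_like(x))

# vjp of sin with a cotangent of ones is cos(x).
assert torch.allclose(grad_x, x.cos())
```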
@@ -236,30 +284,32 @@ def wrapper(*args, **kwargs):
        level = _grad_increment_nesting()
        output, aux, grad_input = None, None, None
        try:
-            args = _wrap_all_tensors(args, level)
-            kwargs = _wrap_all_tensors(kwargs, level)
-            diff_args = _slice_argnums(args, argnums)
-            tree_map_(partial(_create_differentiable, level=level), diff_args)
-
-            output = f(*args, **kwargs)
-            if has_aux:
-                output, aux = output
-
-            if not isinstance(output, torch.Tensor):
-                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
-                                   f'to return a Tensor, got {type(output)}')
-            if output.dim() != 0:
-                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
-                                   'to return a scalar Tensor, got tensor with '
-                                   f'{output.dim()} dims. Maybe you wanted to'
-                                   'use the vjp or jacrev APIs instead?')
-
-            flat_diff_args, spec = tree_flatten(diff_args)
-
-            # NB: need create_graph so that backward pass isn't run in no_grad mode
-            flat_outputs = _as_tuple(output)
-            flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
-            grad_input = tree_unflatten(flat_grad_input, spec)
+            # See NOTE [grad and vjp interaction with no_grad]
+            with torch.enable_grad():
+                args = _wrap_all_tensors(args, level)
+                kwargs = _wrap_all_tensors(kwargs, level)
+                diff_args = _slice_argnums(args, argnums)
+                tree_map_(partial(_create_differentiable, level=level), diff_args)
+
+                output = f(*args, **kwargs)
+                if has_aux:
+                    output, aux = output
+
+                if not isinstance(output, torch.Tensor):
+                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
+                                       f'to return a Tensor, got {type(output)}')
+                if output.dim() != 0:
+                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
+                                       'to return a scalar Tensor, got tensor with '
+                                       f'{output.dim()} dims. Maybe you wanted to'
+                                       'use the vjp or jacrev APIs instead?')
+
+                flat_diff_args, spec = tree_flatten(diff_args)
+
+                # NB: need create_graph so that backward pass isn't run in no_grad mode
+                flat_outputs = _as_tuple(output)
+                flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
+                grad_input = tree_unflatten(flat_grad_input, spec)
 
        finally:
            if grad_input is not None:
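Because `grad` is built on this wrapper, the same behavior holds for `grad_and_value`. A sketch, assuming `grad_and_value` is exported at the functorch top level and returns a `(gradients, value)` pair as its name suggests:

```python
import torch
from functorch import grad_and_value

def loss(w):
    return (w ** 2).sum()

w = torch.ones(3)

# Works the same way under an outer no_grad: the transform's forward pass and
# backward pass both run with grad enabled internally.
with torch.no_grad():
    grads, value = grad_and_value(loss)(w)

assert torch.allclose(grads, 2 * w)           # d/dw sum(w**2) = 2w
assert torch.allclose(value, torch.tensor(3.0))
```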

functorch/csrc/DynamicLayer.cpp

Lines changed: 26 additions & 7 deletions
@@ -33,7 +33,7 @@ class DynamicLayerStackHolder : public c10::DebugInfoBase {
   DynamicLayerStackHolder() {}
   virtual ~DynamicLayerStackHolder() {}
 
-  std::vector<DynamicLayer> dynamicLayerStack = { DynamicLayer(DispatchKey::Autograd, 1) };
+  std::vector<DynamicLayer> dynamicLayerStack = { DynamicLayer(DispatchKey::Autograd, 1, nullopt, true) };
 };
 
 thread_local std::shared_ptr<DynamicLayerStackHolder> kDynamicLayerStack;
@@ -117,13 +117,16 @@ static int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) {
   return layerId;
 }
 
-static int64_t pushDynamicLayer(DispatchKey key, optional<int64_t> batch_size = nullopt) {
+static int64_t pushDynamicLayer(
+    DispatchKey key,
+    optional<int64_t> batch_size = nullopt,
+    optional<bool> prev_grad_mode = nullopt) {
   auto& dynamicLayerStack = dynamicLayerStackAccessor();
   TORCH_INTERNAL_ASSERT(key != DispatchKey::Undefined);
   TORCH_INTERNAL_ASSERT(key != DispatchKey::Batched);
 
   auto layerId = 1 + dynamicLayerStack.size();
-  dynamicLayerStack.emplace_back(key, layerId, batch_size);
+  dynamicLayerStack.emplace_back(key, layerId, batch_size, prev_grad_mode);
 
   if (layerId == 2) {
     // std::cout << "DynamicLayer on" << std::endl;
@@ -134,10 +137,16 @@ static int64_t pushDynamicLayer(DispatchKey key, optional<int64_t> batch_size =
   return layerId;
 }
 
-int64_t initAndPushDynamicLayer(DispatchKey key, optional<int64_t> batch_size) {
-  auto layerId = pushDynamicLayer(key, batch_size);
+int64_t initAndPushDynamicLayer(
+    DispatchKey key,
+    optional<int64_t> batch_size,
+    optional<bool> prev_grad_mode) {
+  auto layerId = pushDynamicLayer(key, batch_size, prev_grad_mode);
   auto& data = getGlobalDynmetaData();
   TORCH_INTERNAL_ASSERT(data.find(layerId) == data.end());
+  if (key == DispatchKey::Autograd) {
+    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
+  }
   data[layerId] = std::make_shared<bool>(true);
   return layerId;
 }
@@ -374,7 +383,6 @@ struct WithoutTop {
     pushDynamicLayer(std::move(layer_));
   }
 
-  bool prev_grad_enabled_;
   DynamicLayer layer_;
 };
 
@@ -394,6 +402,11 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
   auto cur_level = getDynamicLayerStack().back().layerId();
   auto cur_key = getDynamicLayerStack().back().key();
 
+  optional<bool> prev_grad_mode = getDynamicLayerStack().back().prevGradMode();
+  if (cur_key == DispatchKey::Autograd) {
+    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
+  }
+
   auto unwrap = [&](const Tensor& tensor) {
     if (!tensor.defined()) {
       return tensor;
@@ -457,7 +470,13 @@ void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack*
   c10::impl::tls_set_dispatch_key_included(kDynamicLayerBackModeKey, true);
 
   // Re-dispatch
-  op.callBoxed(stack);
+  if (cur_key == DispatchKey::Autograd && *prev_grad_mode == false) {
+    // See NOTE [grad and vjp interaction with no_grad]
+    c10::AutoGradMode guard(*prev_grad_mode);
+    op.callBoxed(stack);
+  } else {
+    op.callBoxed(stack);
+  }
 
   // Step 4, 5, 6
   if (cur_key == DispatchKey::Autograd) {
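The `prev_grad_mode` bookkeeping and the `AutoGradMode` guard above are what keep nested transforms consistent when a `torch.no_grad` sits between grad levels. A hedged sketch of the user-visible behavior this enables (the values are worked arithmetic, not output copied from the test suite):

```python
import torch
from functorch import grad

def f(x):
    # Respected by every grad level above it: `shift` is a constant.
    with torch.no_grad():
        shift = x ** 2
    return x * x - shift

# Second derivative under an outer no_grad: the outer no_grad is ignored by
# both transforms, the inner one is respected by both, so f behaves like
# x**2 minus a constant.
with torch.no_grad():
    second = grad(grad(f))(torch.tensor(2.0))

# f'(x) = 2x, so f''(x) = 2.
assert torch.allclose(second, torch.tensor(2.0))
```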

functorch/csrc/DynamicLayer.h

Lines changed: 24 additions & 2 deletions
@@ -17,21 +17,43 @@ namespace at {
 namespace functorch {
 
 struct TORCH_API DynamicLayer {
-  DynamicLayer(DispatchKey key, int64_t layerId, optional<int64_t> batchSize = nullopt): key_(key), layerId_(layerId), batchSize_(batchSize) {}
+  DynamicLayer(
+      DispatchKey key,
+      int64_t layerId,
+      optional<int64_t> batchSize = nullopt,
+      optional<bool> prev_grad_mode = nullopt):
+    key_(key), layerId_(layerId), batchSize_(batchSize), prevGradMode_(prev_grad_mode)
+  {
+    if (key_ == DispatchKey::Autograd) {
+      TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
+    }
+  }
 
   DispatchKey key() const { return key_; }
   int64_t layerId() const { return layerId_; }
+  // Only valid for vmap
   int64_t batchSize() const {
     TORCH_INTERNAL_ASSERT(batchSize_);
     return *batchSize_;
   }
+  // only valid for grad-based transforms
+  optional<bool> prevGradMode() const {
+    return prevGradMode_;
+  }
 private:
   DispatchKey key_;
   int64_t layerId_;
+
+  // Honestly these should be a union or some extendable metadata class.
+  // Not doing that for now because I don't think we'll use this mechanism for very long.
   optional<int64_t> batchSize_;
+  optional<bool> prevGradMode_;
 };
 
-TORCH_API int64_t initAndPushDynamicLayer(DispatchKey key, optional<int64_t> batch_size = nullopt);
+TORCH_API int64_t initAndPushDynamicLayer(
+    DispatchKey key,
+    optional<int64_t> batch_size = nullopt,
+    optional<bool> prev_grad_mode = nullopt);
 TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
 TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer();
 TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();

functorch/csrc/init.cpp

Lines changed: 3 additions & 1 deletion
@@ -158,7 +158,9 @@ bool dump_tensor(const Tensor& self) {
 }
 
 int64_t _grad_increment_nesting() {
-  return initAndPushDynamicLayer(at::DispatchKey::Autograd);
+  // See NOTE [grad and vjp interaction with no_grad]
+  bool prev_grad_mode = c10::GradMode::is_enabled();
+  return initAndPushDynamicLayer(at::DispatchKey::Autograd, nullopt, prev_grad_mode);
 }
 
 int64_t _grad_decrement_nesting() {
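`_grad_increment_nesting` now snapshots the grad mode that was active when the transform was entered; inside the transformed function, grad is unconditionally enabled (via the `torch.enable_grad()` block in `eager_transforms.py`), and the snapshot is only consulted when dispatch crosses back below the grad layer. A small sketch of the observable consequence, under those assumptions:

```python
import torch
from functorch import grad

def f(x):
    # True even though grad(f) is invoked under torch.no_grad below:
    # grad enables grad mode on entry and only restores the saved
    # ("previous") mode when re-dispatching below its own layer.
    assert torch.is_grad_enabled()
    return (x ** 2).sum()

with torch.no_grad():
    g = grad(f)(torch.randn(3))
```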
