@@ -88,27 +88,75 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, creat
                         for gi, inp in zip(grad_inputs, inputs))
     return grad_inputs
 
+# NOTE [grad and vjp interaction with no_grad]
+#
+# def f(x):
+#   with torch.no_grad():
+#     c = x ** 2
+#   return x - c
+#
+# The thing to consider is if enable_grad is on/off before grad gets called.
+#
+# Case 1: enable_grad is on.
+# grad(f)(x)
+# In this case, `grad` should respect the inner torch.no_grad.
+#
+# Case 2: enable_grad is off
+# with torch.no_grad():
+#     grad(f)(x)
+# In this case, `grad` should respect the inner torch.no_grad, but not the
+# outer one. This is because `grad` is a "function transform": its result
+# should not depend on the result of a context manager outside of `f`.
+#
+# This gives us the following desired behavior:
+# - (nested) grad transforms must obey torch.no_grad inside them
+# - (nested) grad transforms should not obey torch.no_grad outside them
+#
+# To achieve this behavior, upon entering grad/vjp:
+# - we save the current ("previous") is_grad_enabled (*)
+# - we unconditionally enable grad.
+#
+# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
+# off the stack:
+# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
+#   active, all subsequent grad transforms must obey it).
+# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
+#   then we temporarily restore the previous `is_grad_enabled`. This is
+#   because we're crossing the boundary from a `grad` outside the
+#   no_grad to a `grad` inside the no_grad.
+#
+# NB: vjp has some interesting behavior because the vjp's callable can be called
+# under a different grad_mode than the forward computation...
+#
+# TODO: forward-mode AD: does it also respect no_grad? What does that mean
+# for our jvp transform?
+
+
 # How do we increment and decrement the nesting? I don't think we can.
 def vjp(f, *primals):
     level = _grad_increment_nesting()
     try:
-        primals = _wrap_all_tensors(primals, level)
-        diff_primals = _create_differentiable(primals, level)
-        primals_out = f(*diff_primals)
-
-        results = _undo_create_differentiable(primals_out, level)
-        flat_diff_primals, primals_spec = tree_flatten(diff_primals)
-        flat_primals_out, primals_out_spec = tree_flatten(primals_out)
-
-        for primal_out in flat_primals_out:
-            assert isinstance(primal_out, torch.Tensor)
-            if primal_out.is_floating_point() or primal_out.is_complex():
-                continue
-            raise RuntimeError("vjp(f, ...): All outputs of f must be "
-                               "floating-point or complex Tensors, got Tensor "
-                               f"with dtype {primal_out.dtype}")
-
-        def wrapper(cotangents, retain_graph=True, create_graph=True):
+        # See NOTE [grad and vjp interaction with no_grad]
+        with torch.enable_grad():
+            primals = _wrap_all_tensors(primals, level)
+            diff_primals = _create_differentiable(primals, level)
+            primals_out = f(*diff_primals)
+
+            results = _undo_create_differentiable(primals_out, level)
+            flat_diff_primals, primals_spec = tree_flatten(diff_primals)
+            flat_primals_out, primals_out_spec = tree_flatten(primals_out)
+
+            for primal_out in flat_primals_out:
+                assert isinstance(primal_out, torch.Tensor)
+                if primal_out.is_floating_point() or primal_out.is_complex():
+                    continue
+                raise RuntimeError("vjp(f, ...): All outputs of f must be "
+                                   "floating-point or complex Tensors, got Tensor "
+                                   f"with dtype {primal_out.dtype}")
+
+        def wrapper(cotangents, retain_graph=True, create_graph=None):
+            if create_graph is None:
+                create_graph = torch.is_grad_enabled()
             flat_cotangents, cotangents_spec = tree_flatten(cotangents)
             if primals_out_spec != cotangents_spec:
                 raise RuntimeError(
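
A minimal sketch of the two cases described in NOTE [grad and vjp interaction with no_grad] above, to make the intended behavior concrete. This is an editorial illustration rather than part of the diff: it assumes the functorch package built from this branch is importable as `functorch`, and the expected values in the comments follow the behavior the NOTE describes.

import torch
from functorch import grad  # functorch's grad transform

def f(x):
    # The inner no_grad is part of f, so every grad transform must respect it:
    # c is treated as a constant.
    with torch.no_grad():
        c = x ** 2
    return x - c

x = torch.randn([])

# Case 1: grad mode is enabled at the call site.
# d/dx (x - c) == 1 because c does not require grad.
print(grad(f)(x))  # expected: tensor(1.)

# Case 2: the call site itself is under torch.no_grad. Since grad is a
# function transform, the outer no_grad is not respected and the result
# is the same as in Case 1.
with torch.no_grad():
    print(grad(f)(x))  # expected: tensor(1.)

Only the no_grad inside f changes what the transform differentiates; the grad mode at the call site does not.
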
@@ -236,30 +284,32 @@ def wrapper(*args, **kwargs):
         level = _grad_increment_nesting()
         output, aux, grad_input = None, None, None
         try:
-            args = _wrap_all_tensors(args, level)
-            kwargs = _wrap_all_tensors(kwargs, level)
-            diff_args = _slice_argnums(args, argnums)
-            tree_map_(partial(_create_differentiable, level=level), diff_args)
-
-            output = f(*args, **kwargs)
-            if has_aux:
-                output, aux = output
-
-            if not isinstance(output, torch.Tensor):
-                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
-                                   f'to return a Tensor, got {type(output)}')
-            if output.dim() != 0:
-                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
-                                   'to return a scalar Tensor, got tensor with '
-                                   f'{output.dim()} dims. Maybe you wanted to'
-                                   'use the vjp or jacrev APIs instead?')
-
-            flat_diff_args, spec = tree_flatten(diff_args)
-
-            # NB: need create_graph so that backward pass isn't run in no_grad mode
-            flat_outputs = _as_tuple(output)
-            flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
-            grad_input = tree_unflatten(flat_grad_input, spec)
+            # See NOTE [grad and vjp interaction with no_grad]
+            with torch.enable_grad():
+                args = _wrap_all_tensors(args, level)
+                kwargs = _wrap_all_tensors(kwargs, level)
+                diff_args = _slice_argnums(args, argnums)
+                tree_map_(partial(_create_differentiable, level=level), diff_args)
+
+                output = f(*args, **kwargs)
+                if has_aux:
+                    output, aux = output
+
+                if not isinstance(output, torch.Tensor):
+                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
+                                       f'to return a Tensor, got {type(output)}')
+                if output.dim() != 0:
+                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
+                                       'to return a scalar Tensor, got tensor with '
+                                       f'{output.dim()} dims. Maybe you wanted to'
+                                       'use the vjp or jacrev APIs instead?')
+
+                flat_diff_args, spec = tree_flatten(diff_args)
+
+                # NB: need create_graph so that backward pass isn't run in no_grad mode
+                flat_outputs = _as_tuple(output)
+                flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
+                grad_input = tree_unflatten(flat_grad_input, spec)
 
     finally:
         if grad_input is not None:
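
The vjp change in the first hunk also alters the returned callable: `create_graph` now defaults to None and is resolved from the grad mode active when the callable is invoked, which is the "interesting behavior" flagged in the NB of the NOTE. A rough usage sketch, assuming functorch's `vjp` returns the forward output together with the `wrapper` shown above:

import torch
from functorch import vjp  # functorch's vjp transform

x = torch.randn(3)
out, vjp_fn = vjp(torch.sin, x)   # vjp_fn is the `wrapper` from the diff
cotangent = torch.ones_like(out)

# Grad mode is on here, so create_graph resolves to True and the backward
# pass records a graph (the result can be differentiated further).
grads = vjp_fn(cotangent)          # tuple with one entry: cos(x) * cotangent

# Under torch.no_grad the same callable resolves create_graph to False:
# the vjp is still computed, just without building a new graph.
with torch.no_grad():
    grads_no_graph = vjp_fn(cotangent)
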