Commit a80eb84

xmfan authored and pytorchmergebot committed
[ca] support higher order gradients (create_graph=True) (pytorch#153222)
Adds create_graph support if you don't compile, or compile only with torch.compile(backend="eager"). Using a backend that goes through AOTDispatch produces a post-dispatch AOT backward, whose double backward will be silently incorrect if the forward trace involved any ops that are not composite implicit.

Pull Request resolved: pytorch#153222
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#153193
1 parent 37efaf4 commit a80eb84
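For context, a minimal sketch of the newly supported path, mirroring the test added below in test/inductor/test_compiled_autograd.py (it relies on the internal compiled_autograd._enable API that the test itself uses, and on eager-only backends for both the forward and the compiled autograd graph):

import torch
from torch._dynamo import compiled_autograd

def f(x):
    return x**3

x = torch.tensor(2.0, requires_grad=True)

# create_graph=True now composes with compiled autograd as long as no
# AOTDispatch backend is involved.
with compiled_autograd._enable(torch.compile(backend="eager")):
    first = torch.autograd.grad(
        torch.compile(backend="eager")(f)(x), x, create_graph=True
    )[0]                                                           # 3*x**2 -> 12.0
    second = torch.autograd.grad(first, x, create_graph=True)[0]   # 6*x -> 12.0
    third = torch.autograd.grad(second, x, create_graph=True)[0]   # 6.0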

File tree

6 files changed (+90, -32 lines)

test/inductor/test_compiled_autograd.py

Lines changed: 68 additions & 23 deletions
@@ -4108,6 +4108,57 @@ def backward(ctx, gO):
         ):
             fn()
 
+    def test_higher_order_gradients(self):
+        def f(x):
+            return x**3
+
+        def fn(fwd_compiler, ca_compiler):
+            torch.manual_seed(123)
+            x = torch.tensor(2.0, requires_grad=True)
+            first, second, third, fourth = None, None, None, None
+            try:
+                with compiled_autograd._enable(ca_compiler):
+                    first = torch.autograd.grad(
+                        fwd_compiler(f)(x), x, create_graph=True
+                    )[0]
+                    second = torch.autograd.grad(first, x, create_graph=True)[0]
+                    third = torch.autograd.grad(second, x, create_graph=True)[0]
+                    fourth = torch.autograd.grad(third, x, create_graph=True)[0]
+            except RuntimeError as e:
+                assert "does not currently support higher order gradients" in str(e)
+                return (first, second, third, fourth)
+
+            return (first, second, third, fourth)
+
+        def eager():
+            return torch.compile(backend="eager")
+
+        def aot_eager():
+            return torch.compile(backend="aot_eager")
+
+        # Without AOTAutograd, no problem
+        first, second, third, fourth = fn(eager(), eager())
+        self.assertEqual(counters["compiled_autograd"]["captures"], 4)
+        self.assertEqual(first, 12)  # 3x^2
+        self.assertEqual(second, 12)  # 6x
+        self.assertEqual(third, 6)  # 6
+        self.assertEqual(fourth, 0)
+        # and should cache hit
+        counters.clear()
+        _ = fn(eager(), eager())
+        self.assertEqual(counters["compiled_autograd"]["captures"], 0)
+        torch._dynamo.reset()
+
+        # With AOTAutograd, can't create_graph
+        first, second, third, fourth = fn(aot_eager(), aot_eager())
+        self.assertIsNone(second)
+
+        first, second, third, fourth = fn(aot_eager(), eager())
+        self.assertIsNone(second)
+
+        first, second, third, fourth = fn(eager(), aot_eager())
+        self.assertIsNone(third)
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
@@ -4227,6 +4278,10 @@ def wrap_test_class(orig_cls):
     "test_prehook_ordering",  # retains_grad_hooks
     "test_will_engine_execute_node",  # retains_grad_hooks
     "test_backward_to_node",  # retains_grad_hooks
+    "test_backward_with_nonleaf_inputs",  # retains_grad_hook on non-leaf input
+    "test_create_graph_and_full_backward_hook_cycle",  # _pack_with_none
+    "test_full_backward_hook_double_backward",  # _pack_with_none
+    "test_grad_mode_restored_reentrant",  # assertTrue
 }
 
 test_contexts = {
@@ -4246,42 +4301,20 @@ def wrap_test_class(orig_cls):
 
 known_failing_tests = {
     # Category: Compiled autograd
-    "test_grad_mode_restored_reentrant",  # create_graph
     "test_reentrant_with_callbacks_both_depths",  # queue_callback
     "test_reentrant_with_callbacks_depth_0",  # queue_callback
     "test_reentrant_with_callbacks_depth_1",  # queue_callback
     "test_current_graph_task_execution_order",  # nodes are already freed by the time dynamo traces the lifted hook
     "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
     "test_current_node",  # TorchDispatchMode not yet implemented for compiled autograd
     "test_post_accumulate_grad_hook_ordering",  # accuracy error
-    "test_accumulate_grad",  # create_graph
-    "test_anomaly_assign_parent_cleanup",  # create_graph
-    "test_backward_create_graph_warns",  # create_graph
-    "test_backward_with_nonleaf_inputs",  # create_graph
-    "test_create_graph_and_full_backward_hook_cycle",  # create_graph
     "test_current_graph_task_id",  # autograd state already cleared once dynamo is called
-    "test_custom_autograd_repeated_grad_grad",  # create_graph
     "test_custom_function_forward_mode_forward_is_no_op",  # forward AD
     "test_custom_function_forward_mode_inplace_checks",  # forward AD
     "test_custom_function_forward_mode_view_checks",  # forward AD
     "test_custom_function_forward_mode_wrong_formula",  # forward AD
-    "test_default_saved_tensors_hooks_double_backward",  # create_graph
     "test_node_post_hook_registered_during_unpack_hook",  # 'NoneType' object has no attribute 'register_hook'
-    "test_full_backward_hook_double_backward",  # create_graph
-    "test_function",  # create_graph
-    "test_grad",  # create_graph
-    "test_grad_materialize_grads",  # create_graph
-    "test_grad_nonleaf",  # create_graph
-    "test_grad_nonleaf_many_outputs",  # create_graph
-    "test_hessian_vector",  # create_graph
-    "test_inplace_on_view_backward",  # create_graph
     "test_multi_grad_any_hooks",  # register_multi_grad_hook
-    "test_nested_anomaly_detect_nan",  # create_graph
-    "test_nested_anomaly_printstack_cleanup",  # create_graph
-    "test_once_differentiable",  # create_graph
-    "test_saved_variable_packing_unpacking_saved_original_with_hooks",  # create_graph
-    "test_select_sum",  # create_graph, also needs graph breaks
-    "test_custom_autograd_no_early_free",  # create_graph
     "test_custom_function_error",  # vjp
     "test_custom_function_save_for_forward",  # vjp
     "test_dont_materialize_grads",  # undefined grad
@@ -4290,10 +4323,16 @@ def wrap_test_class(orig_cls):
     "test_node_ordering_when_none_returned",  # torch._dynamo.exc.Unsupported: TypeError <built-in method clone
     "test_save_output_nr",  # output_nr grad passed as None
     "test_setup_context_when_forward_has_default_args",  # autograd.Function with class methods
-    "test_lobpcg",  # create_graph
     # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
     "test_grad_nonleaf_register_hook",
     "test_backward_twice_without_saved_values",  # https://github.com/pytorch/pytorch/issues/129938
+    # Category: Higher Order Gradients
+    "test_default_saved_tensors_hooks_double_backward",  # wrong when pack hook returns non-leaf
+    "test_saved_variable_packing_unpacking_saved_original_with_hooks",  # wrong when pack hook returns non-leaf
+    "test_nested_anomaly_detect_nan",  # nested anomaly
+    "test_select_sum",  # batched gradients
+    "test_custom_autograd_no_early_free",  # batched gradients
+    "test_lobpcg",  # NaNs
     # Category: Dynamo (pass when directly running CA graph)
     "test_accumulate_grad_tensor_reference",  # Out of bounds: frame_state_entry.stride[i] is None
     "test_custom_function_exception",  # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START
@@ -4339,8 +4378,14 @@ def wrap_test_class(orig_cls):
     "test_anomaly_mode_no_check_nan",  # different error messages
     "test_anomaly_grad_warnings",  # different error messages
     "test_anomaly_detect_nan",  # fake tensor errors on NaN
+    "test_once_differentiable",  # different node name: CompiledFunctionBackward
+    "test_function",  # different node name: CompiledFunctionBackward
+    "test_inplace_on_view_backward",  # different node name: CompiledFunctionBackward
+    "test_nested_anomaly_printstack_cleanup",  # anomaly NaN error message different
     # Uncategorized
     "test_not_implemented_grad",  # Dynamo changes the types of exceptions
+    "test_grad",  # AOT backward higher order gradients
+    "test_grad_materialize_grads",  # AOT backward higher order gradients
 }
 
 if not HAS_CUDA:

test/test_autograd.py

Lines changed: 9 additions & 5 deletions
@@ -830,7 +830,7 @@ def compute_grad(create_graph):
         x_grad, x_grad_clone = compute_grad(create_graph=False)
         self.assertEqual(x_grad, x_grad_clone * 2)
 
-        # Accumulate out-of-place when create_graph is False
+        # Accumulate out-of-place when create_graph is True
         x_grad, x_grad_clone = compute_grad(create_graph=True)
         self.assertEqual(x_grad, x_grad_clone)
 
@@ -9376,10 +9376,14 @@ def forward(self, x):
         with set_warn_always_context(True):
             with warnings.catch_warnings(record=True) as w:
                 tmp.exp().sum().backward(create_graph=True)
-                self.assertTrue(len(w) == 1)
-                self.assertTrue(
-                    "Using backward() with create_graph=True" in str(w[0].message)
-                )
+                self.assertTrue(w)
+                found = 0
+                for warning in w:
+                    if "Using backward() with create_graph=True" in str(
+                        warning.message
+                    ):
+                        found += 1
+                self.assertEqual(found, 1)
 
         # Remove the backward + create_graph=True cycle
         a.grad = None

torch/_dynamo/compiled_autograd.py

Lines changed: 7 additions & 0 deletions
@@ -423,6 +423,13 @@ def proxy_call_aot_backward(
     aot_id = CompiledFunction._aot_id
     del CompiledFunction
 
+    if torch.is_grad_enabled():
+        for output_alias_info in metadata.output_info:
+            if output_alias_info.requires_grad:
+                raise RuntimeError(
+                    "torch.compile does not currently support higher order gradients."
+                )
+
     @torch._dynamo.allow_in_graph  # type: ignore[misc]
     def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args):
         out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional(
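A hedged sketch of what this new guard catches, mirroring the aot_eager cases in the test above (the name f_aot is illustrative; which grad call raises depends on when the AOT backward is first traced, so only the error message is asserted):

import torch
from torch._dynamo import compiled_autograd

x = torch.tensor(2.0, requires_grad=True)
f_aot = torch.compile(backend="aot_eager")(lambda t: t**3)

with compiled_autograd._enable(torch.compile(backend="eager")):
    try:
        first = torch.autograd.grad(f_aot(x), x, create_graph=True)[0]
        torch.autograd.grad(first, x, create_graph=True)
    except RuntimeError as e:
        # The check turns a silently incorrect double backward into a loud error.
        assert "does not currently support higher order gradients" in str(e)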

torch/_dynamo/polyfills/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -75,12 +75,15 @@ def radians(x):
 
 
 def accumulate_grad(x, new_grad):
+    # polyfills according to the Gradient Layout Contract
     if new_grad is None:
         return
     new_grad_strided = torch.empty_like(x)
     new_grad_strided.copy_(new_grad)
     if x.grad is None:
         x.grad = new_grad_strided
+    elif torch.is_grad_enabled():
+        x.grad = x.grad + new_grad_strided
     else:
         x.grad.add_(new_grad_strided)
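The new elif branch mirrors eager AccumulateGrad, which accumulates out of place while grad mode is enabled so the accumulation itself stays differentiable and previously handed-out .grad tensors are not mutated. A minimal eager-mode sketch of that observable behavior (plain PyTorch, not part of this PR):

import torch

x = torch.tensor(3.0, requires_grad=True)
(x * x).backward(create_graph=True)
g1 = x.grad                          # 2*x -> 6.0
(x * x).backward(create_graph=True)
print(x.grad)                        # 12.0: accumulated out of place into a new tensor
print(g1)                            # still 6.0: the earlier grad tensor was not mutated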

torch/csrc/autograd/engine.cpp

Lines changed: 2 additions & 4 deletions
@@ -1326,8 +1326,8 @@ auto Engine::execute(
 
   auto graph_task = std::make_shared<GraphTask>(
       /* keep_graph */ keep_graph,
-      /* create_graph */ create_graph,
-      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
+      /* grad_mode */ create_graph,
+      /* reentrant_depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
       /* cpu_ready_queue */ local_ready_queue,
       /* graph_roots */ std::move(temp_roots));
 
@@ -1348,8 +1348,6 @@
 
   if (compiled_autograd != nullptr) {
     // see [Note: Compiled Autograd]
-    TORCH_CHECK(
-        !create_graph, "compiled_autograd does not support create_graph");
     _thread_check.release();
     GraphTaskGuard guard(graph_task);
     CheckpointValidGuard cpvguard(graph_task);

torch/csrc/autograd/functions/accumulate_grad.cpp

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ void AccumulateGrad::compiled_args(CompiledNodeArgs& args) const {
     args.collect(variable);
     args.collect(variable.grad());
   }
+  args.collect(GradMode::is_enabled());
   const auto& hook = tensor_post_acc_grad_hooks();
   if (hook != nullptr) {
     hook->compiled_args(args);
