Commit a28dcdb

Revert "[aot][ca] save bw_module in AOTAutogradCache (pytorch#151860)"
This reverts commit 613bd46. Reverted pytorch#151860 on behalf of https://github.com/huydhn: after chatting with @xmfan, the plan is to revert and reland this change instead ([comment](pytorch#151860 (comment))).
1 parent f6db749 commit a28dcdb

6 files changed: +22 -163 lines changed

test/dynamo/test_aot_autograd_cache.py

Lines changed: 0 additions & 112 deletions
@@ -66,118 +66,6 @@ def _clear_dynamo_and_codecache(self):
         torch._dynamo.reset()
         torch._inductor.codecache.PyCodeCache.cache_clear(purge=True)

-    @functorch_config.patch({"enable_autograd_cache": True})
-    @inductor_config.patch(
-        {
-            "fx_graph_cache": True,
-            "fx_graph_remote_cache": False,
-            "autotune_local_cache": True,
-        }
-    )
-    def test_cache_lazy_backward_for_compiled_autograd(self):
-        device = "cpu"
-        dtype = torch.float32
-        dynamic = True
-        """
-        Verify that we can populate and hot load functions from the cache.
-        """
-        if device == GPU_TYPE and not HAS_GPU:
-            raise unittest.SkipTest(f"requires {GPU_TYPE}")
-        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
-            raise unittest.SkipTest("requires SM80 or later")
-
-        def fn(x, y):
-            return x.sin() @ y
-
-        a = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
-        b = torch.rand(100, 100, dtype=dtype, device=device, requires_grad=True)
-
-        # Record artifacts
-        with fresh_inductor_cache():
-            compiled_fn = torch.compile(fn, dynamic=dynamic)
-
-            # A first call should miss in the cache.
-            eager_result = fn(a, b)
-            expected_grads = torch.autograd.grad(eager_result.sum(), inputs=(a, b))
-            compiled_result = compiled_fn(a, b)
-            with torch._dynamo.compiled_autograd._enable(
-                torch.compile(dynamic=dynamic)
-            ):
-                actual_grads = torch.autograd.grad(compiled_result.sum(), inputs=(a, b))
-            if hasattr(a, "_dynamo_weak_dynamic_indices"):
-                del a._dynamo_weak_dynamic_indices
-            self.assertEqual(eager_result, compiled_result)
-            self.assertEqual(expected_grads[0], actual_grads[0])
-            self.assertEqual(expected_grads[1], actual_grads[1])
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 3)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-            self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0)
-            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
-            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
-            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
-            self.assertEqual(counters["compiled_autograd"]["captures"], 1)
-
-            artifacts = torch.compiler.save_cache_artifacts()
-
-            self.assertIsNotNone(artifacts)
-
-            artifact_bytes, cache_info = artifacts
-
-            autotune_expect = 2 if device == GPU_TYPE else 0
-
-            self.assertEqual(len(cache_info.inductor_artifacts), 3)
-            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
-            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
-            self.assertEqual(len(cache_info.pgo_artifacts), 0)
-
-        self._clear_all_caches()
-
-        # Clean triton kernels
-        shutil.rmtree(os.path.join(cache_dir(), "triton"), ignore_errors=True)
-
-        # Hot load and hit, should not recompile
-        with fresh_inductor_cache():
-            cache_info = torch.compiler.load_cache_artifacts(artifact_bytes)
-
-            self.assertEqual(len(cache_info.inductor_artifacts), 3)
-            self.assertEqual(len(cache_info.autotune_artifacts), autotune_expect)
-            self.assertEqual(len(cache_info.aot_autograd_artifacts), 1)
-            self.assertEqual(len(cache_info.pgo_artifacts), 0)
-
-            for i in range(3):
-                counters.clear()
-                eager_result = fn(a, b)
-                expected_grads = torch.autograd.grad(eager_result.sum(), inputs=(a, b))
-                compiled_result = compiled_fn(a, b)
-                with torch._dynamo.compiled_autograd._enable(
-                    torch.compile(dynamic=dynamic)
-                ):
-                    actual_grads = torch.autograd.grad(
-                        compiled_result.sum(), inputs=(a, b)
-                    )
-                if hasattr(a, "_dynamo_weak_dynamic_indices"):
-                    del a._dynamo_weak_dynamic_indices
-                self.assertEqual(eager_result, compiled_result)
-                self.assertEqual(expected_grads[0], actual_grads[0])
-                self.assertEqual(expected_grads[1], actual_grads[1])
-
-                if i == 0:
-                    # initial compile
-                    self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0)
-                    self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 3)
-                    self.assertEqual(
-                        counters["inductor"]["fxgraph_lookup_write_file"], 3
-                    )
-                    self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
-                    self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
-                    self.assertEqual(
-                        counters["aot_autograd"]["autograd_cache_saved"], 0
-                    )
-                    self.assertEqual(counters["compiled_autograd"]["captures"], 1)
-                else:
-                    # no recompiles
-                    self.assertFalse(counters)
-
     @requires_triton()
     @functorch_config.patch({"enable_autograd_cache": True})
     @inductor_config.patch(
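
For reference, the user-facing flow the deleted test exercised looks roughly like the sketch below. This is a condensed, hypothetical repro rather than the test itself; it leans on the same APIs the test used (torch.compile, the private torch._dynamo.compiled_autograd._enable context manager, and torch.compiler.save_cache_artifacts / load_cache_artifacts):

    # Condensed sketch of the deleted test's flow; `fn` and the shapes are illustrative.
    import torch

    def fn(x, y):
        return x.sin() @ y

    a = torch.rand(100, 100, requires_grad=True)
    b = torch.rand(100, 100, requires_grad=True)

    compiled_fn = torch.compile(fn, dynamic=True)
    out = compiled_fn(a, b)
    # Run the backward under compiled autograd (private API, as in the test).
    with torch._dynamo.compiled_autograd._enable(torch.compile(dynamic=True)):
        grads = torch.autograd.grad(out.sum(), inputs=(a, b))

    # Serialize what the forward/backward compilations put in the caches...
    artifacts = torch.compiler.save_cache_artifacts()
    assert artifacts is not None
    artifact_bytes, cache_info = artifacts

    # ...and, after clearing caches (or in a fresh process), hot-load them so the
    # same compile does not have to redo Inductor/AOTAutograd work.
    torch.compiler.load_cache_artifacts(artifact_bytes)

The test is removed here because, with bw_module no longer stored in the cache entry, an AOTAutogradCache hit can no longer feed compiled autograd the backward graph it needs; that is consistent with the RuntimeError added in compiled_autograd.py below.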

test/dynamo/test_structured_trace.py

Lines changed: 5 additions & 6 deletions
@@ -457,13 +457,12 @@ def test_example_training_fn(self):
 {"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
 {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
 {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
-{"dynamo_start": {"stack": "STACK"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
 {"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
-{"dynamo_start": {"stack": "STACK"}, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
-{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
-{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
-{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
-{"compilation_metrics": "METRICS", "frame_id": 5, "frame_compile_id": 0, "attempt": 0}
+{"dynamo_start": {"stack": "STACK"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
+{"compilation_metrics": "METRICS", "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
 """, # noqa: B950
         )

torch/_dynamo/compiled_autograd.py

Lines changed: 12 additions & 8 deletions
@@ -409,11 +409,7 @@ def proxy_call_aot_backward(
         metadata = CompiledFunction.metadata
         maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata
         aot_id = CompiledFunction._aot_id
-        bw_module = ctx._bw_module
-        aot_symints = ctx.symints
-        symints = ctx._get_compiled_autograd_symints()
         del CompiledFunction
-        del ctx

         @torch._dynamo.allow_in_graph  # type: ignore[misc]
         def call_aot_bwd_prologue(ctx_saved_tensors, ctx_symints, *flat_args):
@@ -455,12 +451,13 @@ def num_inputs(graph):

         # set up the proxy inputs to ctx._bw_module
         # the calling convention is: [*symints, *args (primals and tangents), backward_state]
-        num_args = num_inputs(bw_module.graph)
+        num_args = num_inputs(ctx._bw_module.graph)
         pall_args = [
             pgrads[i] for i in range(num_args - int(pbackward_state is not None))
         ]
         # replace the symints with our symints
-        assert len(symints) == len(aot_symints)
+        symints = ctx._get_compiled_autograd_symints()
+        assert len(symints) == len(ctx.symints)
         psymints = [self.to_proxy(e) for e in symints]
         pall_args[: len(symints)] = psymints
         # Add backward_state
@@ -484,7 +481,7 @@ def make_unique(node_name):
             # make it both informative and unique
             return f"aot{deduped_aot_id}_{node_name}"

-        for node in bw_module.graph.nodes:
+        for node in ctx._bw_module.graph.nodes:
             if node.op == "placeholder":
                 ph = pall_args[args_idx].node
                 ph.name = make_unique(node.name)
@@ -501,7 +498,9 @@ def make_unique(node_name):
             elif node.op == "get_attr":
                 name = node.target
                 qualname = self.fx_tracer.get_fresh_qualname(name)
-                setattr(self.fx_tracer.root, qualname, getattr(bw_module, name))
+                setattr(
+                    self.fx_tracer.root, qualname, getattr(ctx._bw_module, name)
+                )
                 result = self.fx_tracer.create_node("get_attr", qualname, (), {})
                 result.name = make_unique(node.name)
                 value_remap[node] = result
@@ -1271,6 +1270,11 @@ def set_node_origin(
         forward_cls = pyobj._forward_cls  # type: ignore[attr-defined]
         if hasattr(forward_cls, "_aot_id"):
             # backward was created by AOT Dispatcher
+            if forward_cls._lazy_backward_info is None:
+                raise RuntimeError(
+                    """This compiled backward function was saved by AOTAutogradCache, which does not support
+compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`."""
+                )
             maybe_aot_id = forward_cls._aot_id
             new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})"
             raw_stack_trace = CapturedTraceback.extract().format()[-1]
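
The RuntimeError added above points users at a workaround. As a hedged illustration (not part of this commit), AOTAutogradCache can be turned off either through the environment variable named in the message or by flipping the functorch config flag that the tests patch (`enable_autograd_cache`):

    # Option 1: disable AOTAutogradCache for a run via the environment,
    # e.g. `TORCHINDUCTOR_AUTOGRAD_CACHE=0 python train.py`.
    import os
    os.environ.setdefault("TORCHINDUCTOR_AUTOGRAD_CACHE", "0")  # must be set before torch reads its config

    # Option 2: flip the same knob in-process through the functorch config module
    # (the flag the tests patch as `enable_autograd_cache`).
    import torch._functorch.config as functorch_config
    functorch_config.enable_autograd_cache = False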

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 1 addition & 5 deletions
@@ -57,7 +57,6 @@
 from .runtime_wrappers import (
     AOTDispatchAutograd,
     AOTDispatchSubclassWrapper,
-    CachedAutogradLazyBackwardCompileInfo,
     CompilerWrapper,
     FunctionalizedRngRuntimeWrapper,
     post_compile,
@@ -552,9 +551,6 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]):

     guards_expr: Optional[str]

-    # # Used by compiled autograd
-    cached_lazy_backward_info: Optional[CachedAutogradLazyBackwardCompileInfo]
-
     # Turn cache entry into the original callable
     def wrap_post_compile(
         self,
@@ -700,7 +696,7 @@ def wrap_post_compile(
                 self.compiled_bw.backward_state_indices,
                 disable_amp,
                 self.indices_of_inps_to_detach,
-                self.cached_lazy_backward_info,
+                None,  # lazy_backward_info
                 aot_config,
                 fw_metadata=self.runtime_metadata,
                 try_save_cache_entry=None,

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py

Lines changed: 2 additions & 14 deletions
@@ -55,7 +55,6 @@
     AOTDispatchSubclassWrapper,
     AOTSyntheticBaseWrapper,
     AutogradLazyBackwardCompileInfo,
-    CachedAutogradLazyBackwardCompileInfo,
     CompilerWrapper,
     DebugAssertWrapper,
     EffectTokensWrapper,
@@ -279,7 +278,6 @@ def aot_dispatch_base(
             backward_time_taken_ns=0,
             sanitized_aot_config=sanitize_aot_config(aot_config),
             guards_expr=guards_expr,
-            cached_lazy_backward_info=None,
         )
         AOTAutogradCache.save(
             cache_info.cache_key, entry, remote=should_use_remote_autograd_cache()
@@ -1283,13 +1281,8 @@ def aot_dispatch_autograd(
         # close over aot_config.cache_info, since aot_config never changes.
         # But closing over random variables is confusing IMO, so I'm leaving it.
         def try_save_cache_entry(  # noqa: F811
-            compiled_bw_func, lazy_backward_info, _fw_metadata, aot_config
+            compiled_bw_func, _fw_metadata, aot_config
         ):
-            bw_module = lazy_backward_info.bw_module
-            bw_module.meta = {}
-            for node in bw_module.graph.nodes:
-                node.meta = {}
-
             fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None)
             fw_debug_lines = getattr(
                 compiled_fw_func, "_fx_graph_cache_debug_lines", []
@@ -1335,18 +1328,13 @@ def try_save_cache_entry(  # noqa: F811
                 backward_time_taken_ns,
                 sanitized_aot_config=sanitize_aot_config(aot_config),
                 guards_expr=guards_expr,
-                cached_lazy_backward_info=CachedAutogradLazyBackwardCompileInfo(
-                    bw_module
-                ),
             )
             remote = should_use_remote_autograd_cache()
             AOTAutogradCache.save(cache_info.cache_key, entry, remote)

         if compiled_bw_func is not None:
             # If we already compiled it we can just run it right now without waiting
-            try_save_cache_entry(
-                compiled_bw_func, lazy_backward_info, fw_metadata, aot_config
-            )
+            try_save_cache_entry(compiled_bw_func, fw_metadata, aot_config)
             try_save_cache_entry = None

         compiled_fn = AOTDispatchAutograd.post_compile(

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 2 additions & 18 deletions
@@ -1484,20 +1484,12 @@ def make_hashable(arg):
 # with compiled autograd. See: https://github.com/pytorch/pytorch/pull/149229#discussion_r2002122645.
 @dataclass
 class AutogradLazyBackwardCompileInfo:
-    bw_module: torch.fx.GraphModule
+    bw_module: Callable
     placeholder_list: list[Any]
     saved_context: Optional[TracingContext]
     saved_compile_context: Optional[CompileContext]


-# On an AOT Autograd cache hit, we already have a lowered backward, so there is usually
-# no need to keep information around for a new lazy compilation. Except for compiled autograd,
-# which wants to retrace this backward into a larger graph, and it needs the graph module to do so.
-@dataclass
-class CachedAutogradLazyBackwardCompileInfo:
-    bw_module: torch.fx.GraphModule  # missing a couple of fields compared to AutogradLazyBackwardCompileInfo's bw_module
-
-
 def _raise_if_functorch_active():
     # not ideal but prevent the user from seeing a nasty traceback - See #138422
     stack = torch._C._functorch.peek_interpreter_stack()
@@ -1917,11 +1909,7 @@ def post_compile(
         backward_state_indices: list[int],
         disable_amp: bool,
         indices_of_inps_to_detach: list[int],
-        lazy_backward_info: Optional[
-            Union[
-                AutogradLazyBackwardCompileInfo, CachedAutogradLazyBackwardCompileInfo
-            ]
-        ],
+        lazy_backward_info: Optional[AutogradLazyBackwardCompileInfo],
         aot_config: AOTConfig,
         *,
         fw_metadata: ViewAndMutationMeta,  # runtime metadata
@@ -2229,9 +2217,6 @@ def _backward_impl(ctx, all_args):

             if CompiledFunction.compiled_bw is None:
                 assert lazy_backward_info is not None
-                assert isinstance(
-                    lazy_backward_info, AutogradLazyBackwardCompileInfo
-                )

                 if not saved_tensors_use_once:
                     fw_metadata.bw_donated_idxs = []
@@ -2276,7 +2261,6 @@ def _backward_impl(ctx, all_args):
                 if try_save_cache_entry is not None:
                     try_save_cache_entry(
                         CompiledFunction.compiled_bw,
-                        lazy_backward_info,
                         fw_metadata,
                         aot_config,
                     )
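
With CachedAutogradLazyBackwardCompileInfo gone, a cache-hit entry no longer carries a backward torch.fx.GraphModule, only a lowered callable. The minimal sketch below (illustrative only; `bw` is a made-up stand-in, not PyTorch code) shows the distinction the removed comment was describing: compiled autograd splices the backward into a larger graph by walking bw_module.graph.nodes, which an FX GraphModule exposes and an opaque compiled callable does not:

    import torch
    import torch.fx

    def bw(grad_out, saved):
        # stand-in for an AOT-generated backward body
        return grad_out * saved.cos()

    gm = torch.fx.symbolic_trace(bw)   # GraphModule: carries a .graph that can be re-traced
    for node in gm.graph.nodes:        # the kind of walk proxy_call_aot_backward performs
        print(node.op, node.target)

    lowered = torch.compile(bw)        # opaque callable: no FX graph attached to walk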
