Skip to content

Commit e3d3ba8

Browse files
yiming0416 and mori360
authored and committed
[graph_trainer] Log transformed graph to tlparse via trace_structured (pytorch#2619)
This PR adds `tlparse_log_graph_pass` that logs post-transform forward/backward graphs to tlparse, replacing `logger.debug(gm.print_readable(...))` calls.
1 parent 44104d9 commit e3d3ba8

File tree

3 files changed

+56
-16
lines changed

3 files changed

+56
-16
lines changed

torchtitan/experiments/graph_trainer/cudagraph.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -145,6 +145,13 @@ def __init__(
145145
# (debug only) whether check static input tensor addresses during runtime
146146
self._should_check_address = should_check_address
147147

148+
self._gm = runnable if isinstance(runnable, torch.fx.GraphModule) else None
149+
150+
def print_readable(self, *args, **kwargs):
151+
"""Delegate to the inner GraphModule's print_readable."""
152+
assert self._gm is not None, "print_readable requires a GraphModule runnable"
153+
return self._gm.print_readable(*args, **kwargs)
154+
148155
def _copy_non_static_inputs(self, *args):
149156
for i in self._input_indices_to_copy:
150157
self._args[i].copy_(args[i])

torchtitan/experiments/graph_trainer/graph_utils.py

Lines changed: 10 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -67,12 +67,6 @@ def export_joint(
6767
torch.fx.traceback.preserve_node_meta(),
6868
):
6969
gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
70-
logger.debug("Dynamo gm:")
71-
logger.debug(
72-
gm.print_readable(
73-
print_output=False, include_stride=True, include_device=True
74-
)
75-
)
7670
_dump_gm(dump_folder, gm, "dynamo_gm")
7771

7872
tracing_context = gm.meta["tracing_context"]
@@ -288,10 +282,6 @@ def compiler(
288282
if passes is None:
289283
passes = DEFAULT_COMPILER_PASSES
290284

291-
logger.debug(f"{name} before compiler:")
292-
logger.debug(
293-
gm.print_readable(print_output=False, include_stride=True, include_device=True)
294-
)
295285
_dump_gm(dump_folder, gm, f"{name}_before_compiler")
296286

297287
if end_with_pass(passes, ["cudagraph_pass"]):
@@ -317,14 +307,18 @@ def compiler(
317307
# Only try to print/dump if gm is still a GraphModule
318308
# (compile_fx_inner returns a CompiledFxGraph which doesn't have print_readable)
319309
if hasattr(gm, "print_readable"):
320-
logger.debug(f"{name} after compiler:")
321-
logger.debug(
322-
gm.print_readable(
323-
print_output=False, include_stride=True, include_device=True
324-
)
325-
)
326310
_dump_gm(dump_folder, gm, f"{name}_after_compiler")
327311

312+
# Log the final transformed graph to tlparse.
313+
from torchtitan.experiments.graph_trainer.passes import tlparse_log_graph_pass
314+
315+
graph_name = (
316+
"aot_forward_graph_transformed"
317+
if is_forward
318+
else "aot_backward_graph_transformed"
319+
)
320+
tlparse_log_graph_pass(gm, example_inputs, graph_name=graph_name)
321+
328322
return gm
329323

330324

torchtitan/experiments/graph_trainer/passes.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
)
2828
from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing
2929
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
30+
from torch._logging import trace_structured
3031
from torch.fx.passes.regional_inductor import regional_inductor
3132
from torch.utils.checkpoint import CheckpointPolicy
3233

@@ -407,6 +408,44 @@ def reassign_to_pg_pass(
407408
return gm
408409

409410

411+
def tlparse_log_graph_pass(
412+
gm: torch.fx.GraphModule,
413+
example_inputs: Sequence[Any],
414+
*,
415+
graph_name: str,
416+
) -> torch.fx.GraphModule:
417+
"""Log the transformed graph to tlparse via trace_structured.
418+
419+
This pass should be added as the last transform in fwd/bwd_transforms
420+
so that the logged graph reflects all prior transformations.
421+
422+
Args:
423+
gm: The graph module to log.
424+
example_inputs: The example inputs (unused, required by protocol).
425+
graph_name: The name for this graph artifact
426+
(e.g. "aot_forward_graph_transformed").
427+
428+
Returns:
429+
The graph module unchanged.
430+
"""
431+
trace_structured(
432+
"artifact",
433+
metadata_fn=lambda: {
434+
"name": graph_name,
435+
"encoding": "string",
436+
},
437+
payload_fn=lambda: gm.print_readable(
438+
print_output=False,
439+
include_stride=True,
440+
include_device=True,
441+
expanded_def=True,
442+
),
443+
expect_trace_id=False,
444+
)
445+
446+
return gm
447+
448+
410449
# Registry mapping pass names to pass functions (for AOT mode fwd/bwd passes)
411450
AVAILABLE_COMPILER_PASSES = {
412451
"auto_bucketing": autobucketing_reordering_pass,

0 commit comments

Comments (0)