
Commit 3feea29

mdbarnesUCSD authored and pytorchmergebot committed
torch.fx: add debug-level logging to Interpreter.run_node (pytorch#117351) (pytorch#166622)
### Summary

Adds a debug-level logging statement to `torch.fx.Interpreter.run_node`, as proposed in [pytorch#117351](pytorch#117351), to make FX graph execution traceable when debugging or instrumenting model transformations.

When debug logging is enabled, each executed node emits a single structured log line formatted via `LazyString(lambda: n.format_node())`, deferring string construction unless logging is active.

### Example Output

With `logging.DEBUG` enabled:

```
run_node x = x()
run_node add = _operator.add(x, 1)
run_node clamp = torch.clamp(add, min=0.0, max=5.0)
run_node output = output(clamp)
```

With `logging.DEBUG` disabled, no additional output is produced (unchanged default behavior).

### Test Plan

Verified locally with Python 3.11 on macOS using a PyTorch build from source.

- With `logging.DEBUG` enabled: each node emits a debug log via `LazyString`.
- With `logging.DEBUG` disabled: no additional output.
- Confirmed all `Interpreter` tests pass locally: `pytest test/test_fx.py -k "Interpreter"`

Updated the example output to reflect the new `_format_fx_node` helper and inclusion of `kwargs`.

Pull Request resolved: pytorch#166622
Approved by: https://github.com/aorenste
1 parent c3c3653 commit 3feea29

File tree

1 file changed: +28 −1 lines changed


torch/fx/interpreter.py

Lines changed: 28 additions & 1 deletion
```diff
@@ -1,11 +1,12 @@
 # mypy: allow-untyped-defs
 import inspect
+import logging
 from contextlib import contextmanager
 from typing import Any, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.fx.traceback as fx_traceback
-from torch._logging import trace_structured
+from torch._logging import LazyString, trace_structured
 from torch.hub import tqdm
 
 from . import config
@@ -21,10 +22,35 @@
 if TYPE_CHECKING:
     from collections.abc import Iterator
 
+log = logging.getLogger(__name__)
 
 __all__ = ["Interpreter", "Transformer"]
 
 
+def _format_fx_node(n):
+    """
+    Format a torch.fx.Node into a human-readable string for debug logging.
+
+    Args:
+        n (torch.fx.Node): The FX node being executed.
+
+    Returns:
+        str: A formatted string describing the node operation, including its
+        name, target, positional arguments, and keyword arguments.
+    """
+    module_prefix = getattr(n.target, "__module__", "")
+    module_prefix = f"{module_prefix}." if module_prefix else ""
+
+    # Handle positional and keyword arguments
+    args = ", ".join(map(str, n.args))
+    kwargs = ", ".join(f"{k}={v}" for k, v in n.kwargs.items())
+    joined = ", ".join(filter(None, [args, kwargs]))
+
+    return (
+        f"{n.name} = {module_prefix}{getattr(n.target, '__name__', n.target)}({joined})"
+    )
+
+
 @compatibility(is_backward_compatible=True)
 class Interpreter:
     """
@@ -261,6 +287,7 @@ def run_node(self, n: Node) -> Any:
         Returns:
             Any: The result of executing ``n``
         """
+        log.debug("run_node %s", LazyString(lambda: _format_fx_node(n)))
         with self._set_current_node(n):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
             assert isinstance(args, tuple)
```

0 commit comments
