
Commit 8a350ad

feat: Improve logging throughout the Dynamo path (#2405)
1 parent 30847c8 commit 8a350ad

6 files changed: +102 −5 lines changed


py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 3 deletions
@@ -313,6 +313,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
+            logger.info("Partitioning the graph via the fast partitioner")
             partitioned_module, supported_ops = partitioning.fast_partition(
                 gm,
                 verbose=settings.debug,
@@ -322,14 +323,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
             logger.error(
                 "Partitioning failed on the subgraph with fast partition. See trace above. "
-                + "Retrying with global partition.",
+                "Retrying with global partition.",
                 exc_info=True,
             )

             fast_partitioner_failed = True
             settings.use_fast_partitioner = False

     if not settings.use_fast_partitioner:
+        logger.info("Partitioning the graph via the global partitioner")
         partitioned_module, supported_ops = partitioning.global_partition(
             gm,
             verbose=settings.debug,
@@ -367,14 +369,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         # Get the submodule inputs for min, opt, max shapes of the graph inputs
         submodule_inputs = partitioning.construct_submodule_inputs(submodule)

+        assert submodule_inputs is not None
+
         logger.debug(
-            "Submodule name: %s\n Input shapes: %s\n %s",
+            "Converting submodule: %s\n Input shapes: %s\n %s",
            str(name),
            [input.shape for input in submodule_inputs],
            str(submodule.graph),
         )

-        assert submodule_inputs is not None
         # Handle long/double inputs if requested by the user
         if settings.truncate_double:
             submodule_inputs = repair_double_inputs(
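The partitioning messages added above are emitted at INFO level and the per-submodule graph dump at DEBUG, so they stay hidden under Python's default WARNING threshold. A minimal sketch, assuming only the standard logging.getLogger(__name__) hierarchy these modules use under the torch_tensorrt.dynamo namespace, of surfacing them:

import logging

# Install a root handler and let DEBUG-level records through.
logging.basicConfig(level=logging.DEBUG)

# The Dynamo-path modules log under torch_tensorrt.dynamo.*, so lowering the
# parent logger's level exposes the new partitioning and conversion messages.
logging.getLogger("torch_tensorrt.dynamo").setLevel(logging.DEBUG)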

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 16 additions & 1 deletion
@@ -2,7 +2,7 @@
 import os
 import warnings
 from datetime import datetime
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple

 import numpy as np
 import tensorrt as trt
@@ -21,6 +21,7 @@
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_node_io,
     get_node_name,
     get_trt_tensor,
 )
@@ -106,6 +107,9 @@ def __init__(
            [dtype._from(o) for o in output_dtypes] if output_dtypes else None
         )

+        # Mapping of constants to shapes and dtypes
+        self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
+
     def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()

@@ -361,8 +365,19 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         n.kwargs = kwargs

         # run the node
+        _LOGGER.debug(
+            f"Running node {self._cur_node_name}, a {self._cur_node.op} node "
+            f"with target {self._cur_node.target} in the TensorRT Interpreter"
+        )
         trt_node: torch.fx.Node = super().run_node(n)

+        if n.op == "get_attr":
+            self.const_mapping[str(n)] = (tuple(trt_node.shape), str(trt_node.dtype))
+
+        _LOGGER.debug(
+            f"Ran node {self._cur_node_name} with properties: {get_node_io(n, self.const_mapping)}"
+        )
+
         # remove "_itensor_to_tensor_meta"
         kwargs = dict(n.kwargs)
         del kwargs["_itensor_to_tensor_meta"]
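The const_mapping added here records, for each get_attr node the interpreter runs, the shape and dtype of the constant that node resolves to, so get_node_io can later report frozen weights next to activation metadata. A standalone illustration of the stored structure (not the interpreter code itself; the conv_weight name is invented):

from typing import Dict, Sequence, Tuple

import torch

const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}

# Mimic what run_node stores after executing a get_attr node that yields a constant.
weight = torch.randn(16, 3, 3, 3)
const_mapping["conv_weight"] = (tuple(weight.shape), str(weight.dtype))

print(const_mapping)
# {'conv_weight': ((16, 3, 3, 3), 'torch.float32')}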

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 71 additions & 0 deletions
@@ -1,3 +1,4 @@
+import collections
 import functools
 import logging
 import re
@@ -8,6 +9,7 @@
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Argument, Target
+from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -44,6 +46,75 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name


+def get_node_io(
+    node: torch.fx.Node, constant_mapping: Dict[str, Tuple[Sequence[int], str]]
+) -> str:
+    """Gets a string representing the node inputs and outputs including tensor shapes and dtypes"""
+
+    def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
+        """Formats the metadata for a single node"""
+        # If the provided data is a simple TensorMetadata object, parse it
+        if isinstance(metadata, TensorMetadata) or issubclass(
+            type(metadata), torch.Tensor
+        ):
+            return f"{tuple(metadata.shape)}@{metadata.dtype}"  # type: ignore
+        # If the provided data is a scalar, return it as is
+        elif isinstance(metadata, (int, float, bool)):
+            return f"{metadata}@Python-{type(metadata)}"
+        # If the provided data is a sequence, recursively parse it
+        elif isinstance(metadata, collections.abc.Sequence):
+            formatted_str = "("
+            for meta in metadata:
+                formatted_str += format_tensor_metadata(meta) + ", "
+
+            return formatted_str[:-2] + ")"
+        else:
+            _LOGGER.warning(
+                f"Detected unparseable type in node formatting: {type(metadata)}"
+            )
+            return ""
+
+    # Format input tensors
+    metadata_string = "Inputs: ("
+
+    # For each input argument, format it accordingly
+    for arg in node.args:
+        if isinstance(arg, torch.fx.Node):
+            if arg.op == "get_attr":
+                shape, dtype = constant_mapping[str(arg)]
+                arg_repr = f"{shape}@{dtype}"
+            elif arg.meta.get("tensor_meta") is not None:
+                arg_repr = format_tensor_metadata(arg.meta["tensor_meta"])
+            elif arg.meta.get("val") is not None:
+                arg_repr = format_tensor_metadata(arg.meta["val"])
+            else:
+                arg_repr = ""
+
+            metadata_string += f"{arg}: {arg_repr}, "
+        else:
+            metadata_string += f"{arg}, "
+
+    metadata_string = (
+        metadata_string[:-2] if metadata_string[-1] != "(" else metadata_string
+    ) + ")"
+
+    # Format output tensors and arguments
+    metadata_string += " | Outputs: ("
+    if node.op == "get_attr":
+        shape, dtype = constant_mapping[str(node)]
+        node_repr = f"{shape}@{dtype}"
+    elif node.meta.get("tensor_meta") is not None:
+        node_repr = format_tensor_metadata(node.meta["tensor_meta"])
+    elif node.meta.get("val") is not None:
+        node_repr = format_tensor_metadata(node.meta["val"])
+    else:
+        node_repr = ""
+    metadata_string += f"{node}: {node_repr}, "
+    metadata_string = metadata_string[:-2] + ")"
+
+    return metadata_string
+
+
 def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
     """Detects whether a call_function node is the only operator on a placeholder"""
     # Returns true if the node operates on a placeholder and is a direct output
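get_node_io produces a one-line summary of a node's inputs and outputs in (shape)@dtype form. A small sketch of exercising it outside the interpreter, assuming a traced module whose nodes carry tensor_meta from ShapeProp; the TinyModel module is invented for illustration:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

from torch_tensorrt.dynamo.conversion.converter_utils import get_node_io


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


gm = symbolic_trace(TinyModel())
ShapeProp(gm).propagate(torch.randn(2, 3))  # populates node.meta["tensor_meta"]

for node in gm.graph.nodes:
    if node.op == "call_function":
        # No get_attr constants in this toy graph, so the constant mapping is empty.
        print(get_node_io(node, {}))

For the add node this prints something along the lines of
Inputs: (relu: (2, 3)@torch.float32, 1) | Outputs: (add: (2, 3)@torch.float32)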

py/torch_tensorrt/dynamo/conversion/truncate_double.py

Lines changed: 7 additions & 0 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import logging
 from typing import Optional, Sequence, Set

 import torch
@@ -8,6 +9,8 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.utils import get_torch_inputs

+logger = logging.getLogger(__name__)
+

 def _extract_downstream_get_nodes(
     module_node: torch.fx.Node, output_indices: Set[int]
@@ -62,6 +65,10 @@ def _repair_64bit_input(
        torch.float64,
     ), f"dtype argument must be torch.float64, got {dtype}"

+    logger.info(
+        f"Downcasting a 64-bit input at position {position} of submodule {submodule_name}"
+    )
+
     # Determine target data type in 32 and 64 bit forms
     dtype_64bit = dtype
     dtype_32bit = torch.float32
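The new INFO message comes out of _repair_64bit_input when a float64 submodule input is downcast. A hypothetical end-to-end sketch that would trigger it, assuming the truncate_double setting referenced in _compiler.py above is exposed as a same-named keyword argument of torch_tensorrt.compile:

import torch
import torch_tensorrt

# A double-precision model whose inputs the Dynamo path would downcast to float32.
model = torch.nn.Linear(8, 8).double().eval().cuda()
inputs = [torch.randn(1, 8, dtype=torch.float64, device="cuda")]

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    truncate_double=True,  # assumed kwarg mirroring settings.truncate_double
)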

py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ def replace_max_pool_with_indices(
                args=node.args,
                kwargs=node.kwargs,
            )
+            maxpool_fused.meta = node.meta

            logger.debug(
                f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused} "

py/torch_tensorrt/dynamo/utils.py

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ def parse_complex_tensor_structs(

    else:
        raise ValueError(
-            f"Invalid input type {type(inputs)} encountered in parse_complex_tensor_structs parsing. "
+            f"Invalid input type {type(inputs)} encountered during Dynamo input parsing. "
            + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
        )
