Skip to content

Commit ce06632

Browse files
fix: update visualization and transform metadata handling
Signed-off-by: Karthik Vetrivel <kvetrivel@nvidia.com>
1 parent ea380ff commit ce06632

File tree

2 files changed

+40
-82
lines changed

2 files changed

+40
-82
lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ transforms:
149149
############################################################################################
150150
visualize_namespace:
151151
stage: visualize
152-
enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/8460
152+
enabled: false
153153
############################################################################################
154154
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
155155
############################################################################################
Lines changed: 39 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,68 @@
11
"""Transformation to the graph to render nicely in model_explorer."""
22

3-
import json
43
from typing import Tuple
54

6-
import torch
75
import torch.export as te
86
from torch.fx import GraphModule
97

108
from ...models.factory import ModelFactory
119
from ...shim.interface import CachedSequenceInterface
10+
from ...utils.logger import ad_logger
1211
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
1312

1413
try:
1514
import model_explorer
16-
from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem
17-
from model_explorer.pytorch_exported_program_adater_impl import (
18-
PytorchExportedProgramAdapterImpl,
19-
)
2015
except ImportError:
2116
model_explorer = None
22-
GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
23-
# Optionally, you can log a warning or handle this gracefully elsewhere
24-
25-
26-
def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
27-
shape = tensor.shape
28-
total_size = 1
29-
for dim in shape:
30-
total_size *= dim
31-
32-
if size_limit < 0 or size_limit >= total_size:
33-
return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist())
34-
35-
return json.dumps(
36-
(tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist()
37-
)
38-
39-
40-
def _get_shape(val):
41-
return json.dumps(
42-
list(
43-
map(
44-
lambda x: int(x) if str(x).isdigit() else str(x),
45-
val.shape,
46-
)
47-
)
48-
)
49-
50-
51-
def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
52-
out_vals = fx_node.meta.get("val")
53-
if out_vals is None:
54-
return
55-
56-
if isinstance(out_vals, (tuple, list)):
57-
for idx, val in enumerate(out_vals):
58-
metadata = MetadataItem(id=str(idx), attrs=[])
59-
if val is None:
60-
continue
61-
dtype = str(val.dtype)
62-
shape = _get_shape(val)
63-
metadata.attrs.append(KeyValue(key="tensor_shape", value=dtype + shape))
64-
node.outputsMetadata.append(metadata)
65-
elif isinstance(out_vals, torch.Tensor):
66-
dtype = str(out_vals.dtype)
67-
shape = _get_shape(out_vals)
68-
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value=dtype + shape)])
69-
node.outputsMetadata.append(metadata)
70-
elif isinstance(out_vals, bool):
71-
metadata = MetadataItem(id="0", attrs=[KeyValue(key="tensor_shape", value="bool[1]")])
72-
node.outputsMetadata.append(metadata)
73-
else:
74-
raise ValueError(f"Unsupported output type: {type(out_vals)}")
75-
76-
77-
# TODO(yudong): make custom_ops configurable
78-
CUSTOM_OPS = (
79-
torch.ops.auto_deploy.torch_dist_all_reduce.default,
80-
torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
81-
torch.ops.aten.slice.Tensor,
82-
torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default,
83-
torch.ops.auto_deploy.torch_linear_simple.default,
84-
torch.ops.aten.split_with_sizes.default,
85-
)
8617

8718

8819
@TransformRegistry.register("visualize_namespace")
8920
class VisualizeNamespace(BaseTransform):
21+
"""Transform to visualize the graph using Model Explorer.
22+
23+
This transform exports the graph module to an ExportedProgram and launches
24+
Model Explorer for interactive visualization. The visualization helps debug
25+
and understand the graph structure after AutoDeploy transformations.
26+
"""
27+
9028
def _apply(
9129
self,
9230
gm: GraphModule,
9331
cm: CachedSequenceInterface,
9432
factory: ModelFactory,
9533
shared_config: SharedConfig,
9634
) -> Tuple[GraphModule, TransformInfo]:
97-
PytorchExportedProgramAdapterImpl.print_tensor = print_tensor
98-
PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
35+
"""Export the graph and launch Model Explorer for visualization.
36+
37+
Args:
38+
gm: The graph module to visualize.
39+
cm: The cached sequence interface with input arguments.
40+
factory: The model factory (unused).
41+
shared_config: Shared configuration across transforms (unused).
42+
43+
Returns:
44+
A tuple of the unchanged graph module and transform info indicating
45+
whether visualization was successful or skipped.
46+
"""
47+
if model_explorer is None:
48+
return gm, TransformInfo(
49+
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
50+
)
51+
52+
try:
53+
# Export graph module to ExportedProgram for visualization
54+
exported_program = te.export(gm, args=(), kwargs=cm.named_args, dynamic_shapes=None)
9955

100-
# TODO(yudong): make viz as non-block call.
101-
ep = te.export(gm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
102-
graph = ep.graph
103-
# Ensure the ops land up in the right module for better viz
104-
for n in graph.nodes:
105-
if n.target in CUSTOM_OPS:
106-
n.meta["nn_module_stack"] = n.args[0].meta["nn_module_stack"]
56+
ad_logger.info("Launching Model Explorer visualization...")
57+
model_explorer.visualize_pytorch("model-viz", exported_program)
10758

108-
model_explorer.visualize_pytorch("model-viz", ep)
59+
return gm, TransformInfo(
60+
skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True
61+
)
10962

110-
return gm, TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
63+
except Exception as e:
64+
ad_logger.error(f"Failed to visualize graph with Model Explorer: {e}")
65+
# Don't fail the pipeline if visualization fails
66+
return gm, TransformInfo(
67+
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
68+
)

0 commit comments

Comments
 (0)