|
1 | 1 | """Transformation to the graph to render nicely in model_explorer.""" |
2 | 2 |
|
import json
from typing import Tuple

import torch
import torch.export as te
from torch.fx import GraphModule

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
13 | 12 |
|
# model_explorer is an optional dependency: when it is missing, all related
# names fall back to None so the transform below can detect the absence and
# skip visualization instead of failing at import time.
try:
    import model_explorer
    from model_explorer.graph_builder import GraphNode, KeyValue, MetadataItem

    # NOTE: "adater" (sic) matches the upstream module's own spelling.
    from model_explorer.pytorch_exported_program_adater_impl import (
        PytorchExportedProgramAdapterImpl,
    )
except ImportError:
    model_explorer = None
    GraphNode = KeyValue = MetadataItem = PytorchExportedProgramAdapterImpl = None
def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16) -> str:
    """Serialize a tensor's values as a JSON string, truncating large tensors.

    Intended to be monkey-patched onto ``PytorchExportedProgramAdapterImpl``,
    hence the unused ``self`` parameter.

    Args:
        tensor: The tensor whose values should be rendered.
        size_limit: Maximum number of elements to emit. A negative value
            disables truncation entirely.

    Returns:
        A JSON array string. Tensors within the limit keep their nested
        (per-dimension) structure; larger tensors are flattened and only the
        first ``size_limit`` elements are emitted.
    """
    # numel() replaces the previous manual product-over-shape loop.
    total_size = tensor.numel()

    # Move to CPU and upcast to float32 once so any dtype (bf16/fp16/int)
    # can be converted through numpy into plain Python floats.
    values = tensor.cpu().detach().to(torch.float32).numpy()

    if size_limit < 0 or size_limit >= total_size:
        return json.dumps(values.tolist())

    return json.dumps(values.flatten()[:size_limit].tolist())
38 | | - |
39 | | - |
40 | | -def _get_shape(val): |
41 | | - return json.dumps( |
42 | | - list( |
43 | | - map( |
44 | | - lambda x: int(x) if str(x).isdigit() else str(x), |
45 | | - val.shape, |
46 | | - ) |
47 | | - ) |
48 | | - ) |
49 | | - |
50 | | - |
def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
    """Attach dtype/shape output metadata from an fx node to a viz graph node.

    Intended to be monkey-patched onto ``PytorchExportedProgramAdapterImpl``,
    hence the unused ``self`` parameter. Reads the fx node's ``val`` meta and
    appends one ``MetadataItem`` per output to ``node.outputsMetadata``.

    Raises:
        ValueError: If ``val`` is neither a tuple/list, a tensor, nor a bool.
    """
    out_vals = fx_node.meta.get("val")
    if out_vals is None:
        return

    if isinstance(out_vals, (tuple, list)):
        # Multi-output node: one entry per output, keyed by its original
        # index; None outputs are skipped (their index is simply absent).
        for idx, val in enumerate(out_vals):
            if val is None:
                continue
            shape_attr = KeyValue(key="tensor_shape", value=str(val.dtype) + _get_shape(val))
            node.outputsMetadata.append(MetadataItem(id=str(idx), attrs=[shape_attr]))
    elif isinstance(out_vals, torch.Tensor):
        shape_attr = KeyValue(key="tensor_shape", value=str(out_vals.dtype) + _get_shape(out_vals))
        node.outputsMetadata.append(MetadataItem(id="0", attrs=[shape_attr]))
    elif isinstance(out_vals, bool):
        bool_attr = KeyValue(key="tensor_shape", value="bool[1]")
        node.outputsMetadata.append(MetadataItem(id="0", attrs=[bool_attr]))
    else:
        raise ValueError(f"Unsupported output type: {type(out_vals)}")
75 | | - |
76 | | - |
# TODO(yudong): make custom_ops configurable
# Ops whose `nn_module_stack` metadata is copied over from their first
# argument before export, so they are grouped under the proper submodule
# in the Model Explorer rendering.
CUSTOM_OPS = (
    torch.ops.auto_deploy.torch_dist_all_reduce.default,
    torch.ops.auto_deploy.trtllm_dist_all_reduce.default,
    torch.ops.aten.slice.Tensor,
    torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default,
    torch.ops.auto_deploy.torch_linear_simple.default,
    torch.ops.aten.split_with_sizes.default,
)
86 | 17 |
|
87 | 18 |
|
@TransformRegistry.register("visualize_namespace")
class VisualizeNamespace(BaseTransform):
    """Interactive graph visualization via Model Explorer.

    Exports the graph module to a ``torch.export`` ExportedProgram and opens
    it in Model Explorer. Purely diagnostic: the graph module is always
    returned unchanged, whether or not visualization succeeds.
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        """Export the graph and launch Model Explorer for visualization.

        Args:
            gm: The graph module to visualize (returned unmodified).
            cm: Supplies the named example inputs used for export.
            factory: Unused; required by the transform interface.
            shared_config: Unused; required by the transform interface.

        Returns:
            ``gm`` plus a TransformInfo: marked skipped when model_explorer
            is unavailable or visualization fails, successful otherwise.
        """
        skipped_info = TransformInfo(
            skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
        )

        # model_explorer is an optional dependency; bail out quietly if absent.
        if model_explorer is None:
            return gm, skipped_info

        try:
            # Re-export so Model Explorer receives a full ExportedProgram.
            ep = te.export(gm, args=(), kwargs=cm.named_args, dynamic_shapes=None)

            ad_logger.info("Launching Model Explorer visualization...")
            model_explorer.visualize_pytorch("model-viz", ep)

            return gm, TransformInfo(
                skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True
            )
        except Exception as e:
            # Visualization is best-effort: log and keep the pipeline running.
            ad_logger.error(f"Failed to visualize graph with Model Explorer: {e}")
            return gm, skipped_info
0 commit comments