Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docsrc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ Tutorials
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`
* :ref:`pre_allocated_output_example`
* :ref:`debugger_example`

.. toctree::
:caption: Tutorials
Expand Down
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Model Zoo
* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
* :ref:`debugger_example`: Debugging Torch-TensorRT Compilation
75 changes: 75 additions & 0 deletions examples/dynamo/debugger_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
.. _debugger_example:

Debugging Torch-TensorRT Compilation
===================================================================

TensorRT conversion can perform many graph transformations and backend specific
optimizations that are sometimes hard to inspect. Torch-TensorRT provides a
Debugger utility to help visualize FX graphs around lowering passes, monitor
engine building, and capture profiling or TensorRT API traces.

In this example, we demonstrate how to:

1. Enable the Torch-TensorRT Debugger context
2. Capture and visualize FX graphs before and/or after specific lowering passes
3. Configure logging directory and verbosity
"""

import os
import tempfile

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

# Directory where the Debugger writes its log file and FX graph visualizations.
temp_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_debugger_example")

# Seed the RNGs so the randomly generated input is reproducible across runs.
np.random.seed(0)
torch.manual_seed(0)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# `weights=None` builds a randomly initialized ResNet-18; it replaces the
# deprecated `pretrained=False` argument of torchvision model builders.
model = models.resnet18(weights=None).to("cuda").eval()
exp_program = torch.export.export(model, tuple(inputs))

# Compilation settings used below.
enabled_precisions = {torch.float}
workspace_size = 20 << 30  # NOTE: defined here but not passed to compile() below
min_block_size = 0
use_python_runtime = False
torch_executed_ops = set()  # empty set; a bare `{}` would be an empty dict

# Everything compiled inside this context is subject to the Debugger settings.
with torch_trt.dynamo.Debugger(
    log_level="debug",
    logging_dir=temp_dir,
    # whether to monitor the engine building process
    engine_builder_monitor=False,
    # fx graph visualization after certain lowering pass
    capture_fx_graph_after=["complex_graph_detection"],
    # fx graph visualization before certain lowering pass
    capture_fx_graph_before=["remove_detach"],
):
    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        immutable_weights=False,
        reuse_cached_engines=False,
    )

    # Run the compiled module once on the sample input.
    trt_output = trt_gm(*inputs)


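"""
The logging directory is ``<system temp dir>/torch_tensorrt_debugger_example``
(for example ``/tmp/torch_tensorrt_debugger_example`` on Linux). It will
contain the following files:

- /tmp/torch_tensorrt_debugger_example/
    - torch_tensorrt_logging.log
    - lowering_passes_visualization/
        - after_complex_graph_detection.svg
        - before_remove_detach.svg
"""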
Loading