diff --git a/docsrc/index.rst b/docsrc/index.rst index 03a0a01f0a..6ced784097 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -71,6 +71,7 @@ Tutorials * :ref:`mutable_torchtrt_module_example` * :ref:`weight_streaming_example` * :ref:`pre_allocated_output_example` +* :ref:`debugger_example` .. toctree:: :caption: Tutorials diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index c4d2baf0e4..b0bde266d8 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -22,3 +22,4 @@ Model Zoo * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`) * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`) * :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`) +* :ref:`debugger_example`: Debugging Torch-TensorRT Compilation \ No newline at end of file diff --git a/examples/dynamo/debugger_example.py b/examples/dynamo/debugger_example.py new file mode 100644 index 0000000000..7f6a3fe92b --- /dev/null +++ b/examples/dynamo/debugger_example.py @@ -0,0 +1,75 @@ +""" +.. _debugger_example: + +Debugging Torch-TensorRT Compilation +=================================================================== + +TensorRT conversion can perform many graph transformations and backend specific +optimizations that are sometimes hard to inspect. Torch-TensorRT provides a +Debugger utility to help visualize FX graphs around lowering passes, monitor +engine building, and capture profiling or TensorRT API traces. + +In this example, we demonstrate how to: + + 1. Enable the Torch-TensorRT Debugger context + 2. Capture and visualize FX graphs before and/or after specific lowering passes + 3. 
Configure logging directory and verbosity
+"""
+
+import os
+import tempfile
+
+import numpy as np
+import torch
+import torch_tensorrt as torch_trt
+import torchvision.models as models
+
+temp_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_debugger_example")
+
+np.random.seed(0)
+torch.manual_seed(0)
+inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
+
+# Randomly initialized weights suffice: the example inspects compilation, not accuracy.
+model = models.resnet18(weights=None).to("cuda").eval()
+exp_program = torch.export.export(model, tuple(inputs))
+enabled_precisions = {torch.float}
+min_block_size = 0
+use_python_runtime = False
+# An empty set (not ``{}``, which is a dict): no ops are forced to run in PyTorch.
+torch_executed_ops = set()
+
+with torch_trt.dynamo.Debugger(
+    log_level="debug",
+    logging_dir=temp_dir,
+    engine_builder_monitor=False,  # whether to monitor the engine building process
+    capture_fx_graph_after=[
+        "complex_graph_detection"
+    ],  # fx graph visualization after certain lowering pass
+    capture_fx_graph_before=[
+        "remove_detach"
+    ],  # fx graph visualization before certain lowering pass
+):
+
+    trt_gm = torch_trt.dynamo.compile(
+        exp_program,
+        tuple(inputs),
+        use_python_runtime=use_python_runtime,
+        enabled_precisions=enabled_precisions,
+        min_block_size=min_block_size,
+        torch_executed_ops=torch_executed_ops,
+        immutable_weights=False,
+        reuse_cached_engines=False,
+    )
+
+    trt_output = trt_gm(*inputs)
+
+
+"""
+The logging directory (``temp_dir``) will contain the following files:
+- <tempfile.gettempdir()>/torch_tensorrt_debugger_example/
+    torch_tensorrt_logging.log
+    - /lowering_passes_visualization/
+        after_complex_graph_detection.svg
+        before_remove_detach.svg
+"""