You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docsrc/tutorials/compile_groot.rst
+28-2Lines changed: 28 additions & 2 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -39,7 +39,7 @@ Torch-TensorRT supports both just-in-time (JIT) compilation via the torch.compil
39
39
By applying a series of graph-level and kernel-level optimizations—including layer fusion, kernel auto-tuning, precision calibration, and dynamic tensor shape handling—Torch-TensorRT produces a specialized TensorRT engine tailored to the target GPU architecture. These optimizations maximize inference throughput and minimize latency, delivering substantial performance gains across both datacenter and edge platforms.
40
40
Torch-TensorRT is designed to operate seamlessly across a wide spectrum of NVIDIA hardware, ranging from high-performance datacenter GPUs (e.g., A100, H100, DGX Spark) to resource-constrained edge devices such as Jetson Thor. This versatility allows developers to deploy the same model efficiently across heterogeneous environments without modifying core code.
41
41
42
-
A key component of this integration is the MutableTorchTensorRTModule (MTTM) — a module provided by Torch-TensorRT. MTTM functions as a transparent and dynamic wrapper around standard PyTorch modules. It automatically intercepts and optimizes the module’s forward() function on-the-fly using TensorRT, while preserving the complete semantics and functionality of the original PyTorch model. This design ensures drop-in compatibility, enabling easy integration of Torch-TensorRT acceleration into complex frameworks, such as multi-stage inference pipelines or Hugging Face Transformers architectures, with minimal code changes.
42
+
A key component of this compiler is the MutableTorchTensorRTModule (MTTM) — a subclass of ``torch.nn.Module`` provided by Torch-TensorRT. MTTM functions as a transparent and dynamic wrapper around standard PyTorch modules. It automatically intercepts and optimizes the module’s forward() function on-the-fly using TensorRT, while preserving the complete semantics and functionality of the original PyTorch model. This design ensures drop-in compatibility, enabling easy integration of Torch-TensorRT acceleration into complex frameworks, such as multi-stage inference pipelines or Hugging Face Transformers architectures, with minimal code changes.
43
43
44
44
Within the GR00T N1.5 model, each component is wrapped with MTTM to achieve optimized performance across all compute stages. This modular wrapping approach simplifies benchmarking, and selective optimization, ensuring that each subcomponent (e.g., the vision, language, or action head modules) benefits from TensorRT’s runtime-level acceleration.
45
45
@@ -66,6 +66,32 @@ The ``fn_name`` argument allows users to target specific submodules of the GR00T
66
66
Results indicate that Torch-TensorRT achieves performance levels comparable to ONNX-TensorRT on the GR00T N1.5 model. However, certain submodules, particularly the LLM component still present optimization opportunities to fully match ONNX-TensorRT performance
67
67
Support for Torch-TensorRT is currently available in this `PR <https://github.com/NVIDIA/Isaac-GR00T/pull/419>`_ and will be merged. Results indicate that Torch-TensorRT achieves performance levels comparable to ONNX-TensorRT on the GR00T N1.5 model. However, certain submodules, particularly the LLM component still present optimization opportunities to fully match ONNX-TensorRT performance
68
68
69
+
VLA Optimizations
70
+
--------------------
71
+
72
+
The following optimizations have been applied to the components of the GR00T N1.5 model to improve performance using Torch-TensorRT:
73
+
74
+
1) Vision Transformer (ViT)
75
+
- The ViT component is optimized by using the Torch-TensorRT MutableTorchTensorRTModule (MTTM). TensorRT optimizations include layer fusion, kernel auto-tuning, dynamic shape handling.
76
+
- FP8 quantization support is available to reduce model size and improve performance.
77
+
- For the SiglipVisionModel, the ``SiglipMultiheadAttentionPoolingHead`` of the ViT component is disabled to eliminate unnecessary latency overhead, as this layer is not utilized by the downstream model. See the implementation `here <https://github.com/peri044/Isaac-GR00T/blob/6b34a65e02b07b19d689498ec75066792b4bb738/deployment_scripts/run_groot_torchtrt.py#L258-L261>`_.
78
+
79
+
2) Text Transformer (LLM)
80
+
- MTTM support for the LLM component and similar TensorRT optimizations apply.
81
+
- FP8 quantization support is available to reduce model size and improve performance.
82
+
83
+
3) Flow Matching Action Head
84
+
- The Flow Matching Action Head component consists of 5 different components: VLM backbone processor, State encoder, Action encoder and Action decoder and DiT.
85
+
- VLM backbone processor uses a LayerNorm layer and Self-Attention Transformer network to process the outputs of VLM (ViT + LLM). We merge these two components into a single ``torch.nn.Module`` to minimize graph fragmentation and improve performance. See the implementation `here <https://github.com/peri044/Isaac-GR00T/blob/6b34a65e02b07b19d689498ec75066792b4bb738/deployment_scripts/run_groot_torchtrt.py#L485-L506>`_.
86
+
- State encoder, Action encoder and Action Decoder use a Multi-Layer Perceptron (MLP) like networks to encode the state and action vectors. These are wrapped with MTTM and standard TensorRT optimizations apply.
87
+
- DiT is a Diffusion-Transformer model that is used to generate the action vector. It is wrapped with MTTM and standard TensorRT optimizations apply. FP8 quantization support is available to this component.
88
+
89
+
4) Dynamic Shape Management
90
+
- A general optimization that provides performance improvements is using dynamic shapes only when necessary. In the GR00T N1.5 model, dynamic shapes are applied selectively to the LLM and DiT components where input dimensions may vary.
91
+
- For components with predictable input sizes, fixed batch dimensions are preferred. Specifying a batch size as dynamic can reduce performance compared to a fixed batch size when the dimensions are known in advance, as TensorRT can apply more aggressive optimizations with static shapes.
92
+
93
+
While these optimizations have been specifically applied to the GR00T N1.5 model, many of them are generalizable to other Vision-Language-Action (VLA) models. Techniques such as selective dynamic shape management, component-level MTTM wrapping, and FP8 quantization can be adapted to similar architectures to achieve comparable performance improvements.
94
+
69
95
RoboCasa Simulation
70
96
--------------------
71
97
@@ -85,7 +111,7 @@ You can then use the following command to start the simulation:
This would start the simulation, display the success rate and record the videos in ``videos`` directory.
114
+
This would start the simulation, display the success rate and record the videos in ``videos`` directory. Here is a sample video of the GR00T N1.5 model in the simulation:
89
115
90
116
.. note::
91
117
If you are running Isaac GR00T in a Docker environment, you can create two separate tmux sessions and launch both Docker containers on the same network to enable inter-container communication. This allows the inference service and simulation service to communicate seamlessly across containers.
0 commit comments