
DINOv2 (Vision Transformer) FMHA Fusion Failure in TensorRT 10.8.0.43 with PyTorch 2.7.1 ONNX Export (RTX 3080) #4537

@CVKim

Description

My goal is to optimize Transformer-based models, specifically a DINOv2 model, by converting it from ONNX to TensorRT to apply FMHA (Fused Multi-Head Attention) for performance optimization.

However, the problem appears as early as the ONNX export stage (i.e., in PyTorch): the expected MHA-related layers that would indicate FMHA fusion are not generated in the ONNX graph.

I have run experiments with the Python script provided below, generating ONNX models both with and without flash attention explicitly enabled via torch.backends.cuda.sdp_kernel. In neither case are the desired FMHA-fused layers observed.
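
For reference, torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch releases in favor of torch.nn.attention.sdpa_kernel. The snippet below is a minimal sketch (assuming PyTorch 2.3 or newer) of the same backend selection with the newer API; DinoV2WrapperSDPA is a hypothetical name mirroring the wrapper used in the export script further down.

    # Sketch only: SDPA backend selection via the newer torch.nn.attention API
    # (assumes PyTorch >= 2.3; class name is hypothetical).
    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    class DinoV2WrapperSDPA(torch.nn.Module):
        def __init__(self, model, use_flash_attention: bool):
            super().__init__()
            self.model = model
            # FLASH_ATTENTION requests the fused flash kernel; MATH is the unfused reference path.
            self.backends = [SDPBackend.FLASH_ATTENTION] if use_flash_attention else [SDPBackend.MATH]

        def forward(self, x):
            # Restrict the SDPA backends available inside the wrapped model's forward pass.
            with sdpa_kernel(self.backends):
                return self.model(x)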

Environment

  • TensorRT Version: 10.8.0.43
  • NVIDIA GPU: NVIDIA GeForce RTX 3080 (from the issue title)
  • NVIDIA Driver Version: not specified
  • CUDA Version: 12.8
  • cuDNN Version: not specified
  • Operating System: not specified
  • Python Version: not specified
  • PyTorch Version: 2.7.1 (from the issue title)
  • Baremetal or Container: not specified

Relevant Files

  • Python ONNX Export Script:

    import torch
    import argparse
    
    class DinoV2Wrapper(torch.nn.Module):
        def __init__(self, model, use_flash_attention):
            super().__init__()
            self.model = model
            self.use_flash_attention = use_flash_attention
    
        def forward(self, x):
            with torch.backends.cuda.sdp_kernel(
                enable_flash=self.use_flash_attention,
                enable_math=not self.use_flash_attention,
                enable_mem_efficient=False
            ):
                return self.model(x)
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--disable_flash", action="store_true")
        args = parser.parse_args()
    
        use_flash = not args.disable_flash
    
        print("Loading DINOv2 ViT-L/14 model...")
        # Assuming dinov2 is loadable, e.g., via torch.hub.load
        model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14").eval().cuda().half()
        print("Model loaded.")
    
        wrapped_model = DinoV2Wrapper(model, use_flash_attention=use_flash)
    
        dummy_input = torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.float16)
        output_onnx = "dinov2_flash.onnx" if use_flash else "dinov2_baseline.onnx"
    
        print(f"Exporting model to {output_onnx} with opset 23 and dynamo=True...")
    
        torch.onnx.export(
            wrapped_model,
            dummy_input,
            output_onnx,
            opset_version=23,
            dynamo=True,  # Enable new ONNX Exporter in Torch 2.7+
            input_names=["images"],
            output_names=["features"],
            dynamic_axes={
                "images": {0: "batch_size"},
                "features": {0: "batch_size"},
            },
            do_constant_folding=True,
            verbose=False
        )
    
        print(f"✓ Successfully exported to {output_onnx}")
  • Images/Logs demonstrating the issue: none attached yet; a small script for checking the exported ONNX graphs for a fused attention node is sketched below.
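
To check what the exporter actually produced, the operator types in each ONNX graph can be counted with the standard onnx package. The snippet below is a small sketch assuming the file names used in the script above; the operator names it reports are simply the ones I would expect to distinguish a single fused attention-style node from the decomposed MatMul/Softmax pattern.

    # Sketch: count operator types in the exported ONNX graphs
    # (assumes the `onnx` package and the file names produced by the export script).
    from collections import Counter
    import onnx

    for path in ("dinov2_flash.onnx", "dinov2_baseline.onnx"):
        model = onnx.load(path)
        ops = Counter(node.op_type for node in model.graph.node)
        print(path)
        # A single attention-style node vs. the decomposed MatMul/Softmax chain.
        for op in ("Attention", "MultiHeadAttention", "MatMul", "Softmax"):
            print(f"  {op}: {ops[op]}")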

Steps To Reproduce

  1. Save the provided Python script as export_dinov2.py.
  2. Run the script to generate ONNX models:
    python export_dinov2.py
    # This generates 'dinov2_flash.onnx' (with enable_flash=True)
    python export_dinov2.py --disable_flash
    # This generates 'dinov2_baseline.onnx' (with enable_flash=False)
  3. Build TensorRT engines using trtexec with the generated ONNX models. (Example command, please adjust path to trtexec if not in PATH):
    trtexec --onnx=dinov2_flash.onnx \
      --saveEngine=dinov2_flash_fp16.plan \
      --fp16 \
      --builderOptimizationLevel=5 \
      --verbose --profilingVerbosity=detailed \
      --minShapes=images:1x3x224x224 \
      --optShapes=images:1x3x224x224 \
      --maxShapes=images:4x3x336x336 \
      --iterations=300 \
      --avgRuns=100 \
      --dumpProfile \
      --memPoolSize=workspace:1000
  4. Review the trtexec output logs (especially with --profilingVerbosity=detailed); a small script for scanning the dumped layer information for fused attention kernels is sketched after this list.
  • Expected Behavior: an mha or fused_mha_v2 layer should be listed in the TensorRT graph/profile, indicating FMHA fusion.
  • Observed Behavior: no such fused attention kernel appears in the logs for the DINOv2 conversion; instead, the attention operations remain as separate MatMul, Softmax, etc., layers.
  • Full traceback of errors encountered: N/A (the conversion succeeds, but no FMHA fusion is observed, so there is no error traceback).
  • Have you tried the latest release?: Yes, I am using TensorRT 10.8.0.43 (GA release), which should support FMHA.
  • Can this model run on other frameworks?: Yes, the ONNX models run successfully with ONNX Runtime (polygraphy run dinov2_flash.onnx --onnxrt).
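
As a concrete aid for step 4, the builder's chosen layer names can be dumped by adding --exportLayerInfo=layer_info.json to the trtexec command and then scanned for attention-related kernels. The snippet below is a rough sketch under that assumption; the substrings it searches for are educated guesses at typical fused-MHA kernel naming, not an authoritative list.

    # Sketch: scan trtexec layer information for fused-attention kernels.
    # Assumes --exportLayerInfo=layer_info.json was added to the trtexec command
    # above (together with --profilingVerbosity=detailed, as already used).
    import pathlib

    text = pathlib.Path("layer_info.json").read_text().lower()
    for needle in ("mha", "fmha", "flash", "attention"):
        print(f"{needle!r}: {text.count(needle)} occurrence(s)")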

Additional Context / Questions for NVIDIA:

I am seeking guidance on the best practices to ensure FMHA fusion for DINOv2-like Vision Transformer models when converting to TensorRT engines.

  1. FMHA applicability and setup: How can FMHA be correctly applied to DINOv2 models via ONNX to TensorRT? What specific ONNX graph patterns, PyTorch export flags, trtexec flags, or TensorRT Builder configurations for version 10.8.0.43 (or a recommended higher version) are necessary to guarantee FMHA fusion for DINOv2's attention layers?
  2. Performance Comparison Methodology: What is the recommended methodology for a comprehensive performance comparison (latency, throughput, memory usage) across precisions (FP32, FP16, BF16, INT8) and dynamic batch sizes / sequence lengths? For example, which trtexec flags (such as --avgRuns and --noDataTransfers) and which nsys profile settings are appropriate for per-kernel timings?
  3. TensorRT Version Compatibility: If the TensorRT version I am using is incompatible, please let me know which version I should test with. My ultimate goal is to run inference on, and inspect, DINOv2-family models in C++ using the .trt engine; a sketch of inspecting the built engine from Python follows this list.
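
On that note, a saved engine can also be inspected directly from Python via TensorRT's engine-inspector API before moving to C++. The sketch below is only an illustration: it assumes the tensorrt Python package, the dinov2_flash_fp16.plan file from the trtexec command above, and an engine built with --profilingVerbosity=detailed (otherwise per-layer names are not retained); the substrings it searches for are guesses at typical fused-attention layer names, not an official list.

    # Sketch: inspect a saved TensorRT engine for fused-attention layers.
    # Assumes the `tensorrt` Python package and the plan file produced above;
    # detailed layer names are only available if the engine was built with
    # --profilingVerbosity=detailed.
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(logger)
    with open("dinov2_flash_fp16.plan", "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())

    inspector = engine.create_engine_inspector()
    # Dump per-layer information as a JSON string and do a crude substring check.
    info = inspector.get_engine_information(trt.LayerInformationFormat.JSON).lower()
    for needle in ("mha", "fmha", "attention"):
        print(f"{needle!r}: {info.count(needle)} occurrence(s)")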

Any insights or recommended strategies would be greatly appreciated.

Labels

Module:ONNX (Issues relating to ONNX usage and import)
