Incorrect outputs of TensorRT 10.7 when compiling F.scaled_dot_product_attention on an L4 GPU #4333

@ohadravid

Description

Compiling an attention layer that uses torch's F.scaled_dot_product_attention into a TRT engine results in incorrect outputs.

Environment

Tested against nvcr.io/nvidia/pytorch:24.12-py3 on an L4 GPU

TensorRT Version: 10.7.0.23

NVIDIA GPU: L4

NVIDIA Driver Version:

CUDA Version: 12.6.3

CUDNN Version: 9.6.0.74

Operating System: Ubuntu 24.04

Python Version (if applicable): 3.12

Tensorflow Version (if applicable):

PyTorch Version (if applicable): 2.6.0a0+df5bbc09d1.

Baremetal or Container (if so, version): Container, nvcr.io/nvidia/pytorch:24.12-py3

Relevant Files

Model link:

A gist to reproduce the issue

At a high level, we have:

class AttentionUsingScaledDotProduct(nn.Module):
    ...
    def forward(self, x):
        ...
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_drop.p if self.training else 0.0,
            scale=self.scale,
        )
        ...


class ExplicitAttention(nn.Module):
    ...
    def forward(self, x):
        ...
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        ...

And the first version produces an incorrect TRT engine.
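
For completeness, a minimal self-contained sketch of the two variants plus the ONNX export step is shown below. The hidden size, head count, qkv layout, input shape, and opset are assumptions made for illustration; the linked gist is the authoritative reproduction.

import torch
import torch.nn as nn
import torch.nn.functional as F


class _AttentionBase(nn.Module):
    # Shared qkv projection; dim/num_heads are placeholder values for the sketch.
    def __init__(self, dim: int = 768, num_heads: int = 8, attn_drop: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

    def _split_qkv(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        return qkv.permute(2, 0, 3, 1, 4).unbind(0)  # q, k, v: (B, heads, N, head_dim)


class AttentionUsingScaledDotProduct(_AttentionBase):
    def forward(self, x):
        B, N, C = x.shape
        q, k, v = self._split_qkv(x)
        # Fused attention; the 1/sqrt(head_dim) scale is passed explicitly.
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.0,
            scale=self.scale,
        )
        return self.proj(x.transpose(1, 2).reshape(B, N, C))


class ExplicitAttention(_AttentionBase):
    def forward(self, x):
        B, N, C = x.shape
        q, k, v = self._split_qkv(x)
        # Same math spelled out: scale q up front, then matmul + softmax.
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        attn = self.attn_drop(attn)
        return self.proj((attn @ v).transpose(1, 2).reshape(B, N, -1))


if __name__ == "__main__":
    sdpa, explicit = AttentionUsingScaledDotProduct().eval(), ExplicitAttention().eval()
    explicit.load_state_dict(sdpa.state_dict())  # identical weights in both variants
    x = torch.randn(1, 197, 768)
    torch.onnx.export(sdpa, (x,), "attention_sdpa.onnx", opset_version=17)
    torch.onnx.export(explicit, (x,), "attention_explicit.onnx", opset_version=17)

Exporting both variants with identical weights makes the two ONNX graphs directly comparable node-for-node.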

This is a comparison of both ONNX files, explicit version (left) and scaled dot product version (right):

[Image: side-by-side view of the two exported ONNX graphs]

Interestingly, patching the bad ONNX to have B=0.125 on the right node (and B=1 on the left) fixes the issue, but not the other way around.

[Image]
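
For reference, the patch described above can be applied with the onnx Python API roughly as follows (a sketch: the file name and initializer name are assumptions, found in practice by inspecting the graph in Netron, and the scalar may instead live in a Constant node attribute):

import numpy as np
import onnx
from onnx import numpy_helper

model = onnx.load("attention_sdpa.onnx")  # assumed file name

# Overwrite the scalar feeding the Mul node with the expected 1/sqrt(head_dim)
# value (0.125 for head_dim=64).
for init in model.graph.initializer:
    if init.name == "attn_scale":  # assumed initializer name
        init.CopyFrom(numpy_helper.from_array(np.array(0.125, dtype=np.float32), name=init.name))

onnx.save(model, "attention_sdpa_patched.onnx")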

Steps To Reproduce

python min_repro.py

Prints:

Torch: [explicit<->sdpa] Is allclose? True
Torch: [explicit<->mha_fwd] Is allclose? True
Torch: [explicit<->sdpa] Total difference: tensor(0.1525)
Torch: [explicit<->mha_fwd] Total difference: tensor(0.1528)
...
TRT: [explicit<->sdpa] Is allclose? False
TRT: [explicit<->sdpa] Total difference: tensor(1373.3945, device='cuda:0')
TRT: [explicit<->mha_fwd] Is allclose? True
TRT: [explicit<->mha_fwd] Total difference: tensor(0.2232, device='cuda:0')
TRT: Explicit Attention: tensor([-0.0924,  0.0904, -0.3557, -0.4190,  0.0607,  0.2370,  0.0293, -0.2666,
        -0.1009,  0.0399, -0.2203, -0.1261,  0.2094,  0.2609, -0.1429, -0.0457,
        -0.0566, -0.2337, -0.0609,  0.0572,  0.1024, -0.1487,  0.2335, -0.0703,
        -0.2909,  0.0832, -0.1907,  0.0462, -0.3819, -0.2341,  0.1027,  0.2187],
       device='cuda:0')
TRT: Scaled Dot Product Attention: tensor([-0.0920,  0.0905, -0.3552, -0.4191,  0.0605,  0.2372,  0.0291, -0.2665,
        -0.1009,  0.0400, -0.2203, -0.1254,  0.2096,  0.2612, -0.1429, -0.0454,
        -0.0563, -0.2348, -0.0605,  0.0571,  0.1028, -0.1491,  0.2334, -0.0706,
        -0.2912,  0.0824, -0.1899,  0.0453, -0.3816, -0.2335,  0.1030,  0.2190],
       device='cuda:0')
TRT: MHA Forward: tensor([-0.0924,  0.0904, -0.3557, -0.4190,  0.0607,  0.2370,  0.0293, -0.2666,
        -0.1009,  0.0399, -0.2203, -0.1261,  0.2094,  0.2609, -0.1429, -0.0457,
        -0.0566, -0.2337, -0.0609,  0.0572,  0.1024, -0.1487,  0.2335, -0.0703,
        -0.2909,  0.0832, -0.1907,  0.0462, -0.3819, -0.2341,  0.1027,  0.2187],
       device='cuda:0')
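
For context, the "Is allclose?" / "Total difference" lines above come from a comparison along these lines (a guess at the helper, for illustration only; the exact tolerance and reduction are defined in the gist):

import torch

def compare(name: str, ref: torch.Tensor, out: torch.Tensor, atol: float = 1e-3) -> None:
    # Hypothetical helper; "Total difference" is assumed to be the summed absolute error.
    print(f"[{name}] Is allclose? {torch.allclose(ref, out, atol=atol)}")
    print(f"[{name}] Total difference: {(ref - out).abs().sum()}")

# e.g. compare("explicit<->sdpa", explicit_out, sdpa_out), where the outputs come
# either from the torch modules or from the corresponding TRT engines.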

Commands or scripts:

Have you tried the latest release?:

Yes

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):

Yes, running the model with ONNX Runtime works as expected.
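
The ONNX Runtime check can be done along these lines (a sketch; the file name and input shape are assumptions, and the same comparison can be run from the CLI with polygraphy run attention_sdpa.onnx --onnxrt):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("attention_sdpa.onnx", providers=["CPUExecutionProvider"])  # assumed file name
x = np.random.randn(1, 197, 768).astype(np.float32)  # assumed input shape
(out,) = sess.run(None, {sess.get_inputs()[0].name: x})
print(out.reshape(-1)[:8])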

Metadata

Labels

Module:Accuracy: Output mismatch between TensorRT and other frameworks
internal-bug-tracked: Tracked internally, will be fixed in a future release.
triaged: Issue has been triaged by maintainers
