Description
Compiling an attention layer that uses torch's F.scaled_dot_product_attention into a TRT engine produces incorrect outputs.
Environment
Tested against nvcr.io/nvidia/pytorch:24.12-py3 on an L4 GPU
TensorRT Version: 10.7.0.23
NVIDIA GPU: L4
NVIDIA Driver Version:
CUDA Version: 12.6.3
CUDNN Version: 9.6.0.74
Operating System: Ubuntu 24.04
Python Version (if applicable): 3.12
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 2.6.0a0+df5bbc09d1
Baremetal or Container (if so, version): Container, nvcr.io/nvidia/pytorch:24.12-py3
Relevant Files
Model link:
At a high level, the two variants are:
class AttentionUsingScaledDotProduct(nn.Module):
    ...
    def forward(self, x):
        ...
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_drop.p if self.training else 0.0,
            scale=self.scale,
        )
        ...
class ExplicitAttention(nn.Module):
    ...
    def forward(self, x):
        ...
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        ...
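For context, a minimal self-contained sketch of how the elided parts might look, assuming a standard ViT-style attention block with a fused qkv projection and an output projection. The sizes (dim=256, num_heads=4, i.e. head_dim=64 and the 0.125 scale seen in the graph), the qkv layout, and the proj layer are assumptions, not the actual model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionUsingScaledDotProduct(nn.Module):
    def __init__(self, dim=256, num_heads=4, attn_drop=0.0):  # assumed sizes
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5   # 0.125 for head_dim=64
        self.qkv = nn.Linear(dim, dim * 3)        # assumed fused qkv projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each: (B, heads, N, head_dim)
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.0,
            scale=self.scale,
        )
        x = x.transpose(1, 2).reshape(B, N, -1)
        return self.proj(x)

class ExplicitAttention(nn.Module):
    def __init__(self, dim=256, num_heads=4, attn_drop=0.0):  # assumed sizes
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(x)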
The first (scaled_dot_product_attention) version is the one that produces an incorrect TRT engine.
[Image: comparison of the two exported ONNX graphs, explicit version (left) and scaled-dot-product version (right)]
Interestingly, patching the bad ONNX so that the right-hand node has B=0.125 (and the left-hand one has B=1) fixes the issue, but the reverse patch does not.
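For reference, a hedged sketch of the manual patch described above, assuming the scale values end up as scalar initializers feeding the two Mul nodes (if they are baked into Constant nodes instead, the same idea applies to the node's value attribute). The file name and initializer names are placeholders that have to be read off the real graph, e.g. in Netron:

import numpy as np
import onnx
from onnx import numpy_helper

def set_scalar_initializer(model, name, value):
    # Overwrite a scalar initializer in place, keeping its dtype and name.
    for init in model.graph.initializer:
        if init.name == name:
            arr = numpy_helper.to_array(init)
            init.CopyFrom(numpy_helper.from_array(np.array(value, dtype=arr.dtype), name))
            return
    raise KeyError(f"initializer {name!r} not found")

m = onnx.load("attention_sdpa.onnx")                  # placeholder file name
set_scalar_initializer(m, "sdpa_mul_scale", 0.125)    # placeholder: the "right" Mul's B
set_scalar_initializer(m, "sdpa_mul_identity", 1.0)   # placeholder: the "left" Mul's B
onnx.save(m, "attention_sdpa_patched.onnx")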
Steps To Reproduce
python min_repro.py
Prints:
Torch: [explicit<->sdpa] Is allclose? True
Torch: [explicit<->mha_fwd] Is allclose? True
Torch: [explicit<->sdpa] Total difference: tensor(0.1525)
Torch: [explicit<->mha_fwd] Total difference: tensor(0.1528)
...
TRT: [explicit<->sdpa] Is allclose? False
TRT: [explicit<->sdpa] Total difference: tensor(1373.3945, device='cuda:0')
TRT: [explicit<->mha_fwd] Is allclose? True
TRT: [explicit<->mha_fwd] Total difference: tensor(0.2232, device='cuda:0')
TRT: Explicit Attention: tensor([-0.0924, 0.0904, -0.3557, -0.4190, 0.0607, 0.2370, 0.0293, -0.2666,
-0.1009, 0.0399, -0.2203, -0.1261, 0.2094, 0.2609, -0.1429, -0.0457,
-0.0566, -0.2337, -0.0609, 0.0572, 0.1024, -0.1487, 0.2335, -0.0703,
-0.2909, 0.0832, -0.1907, 0.0462, -0.3819, -0.2341, 0.1027, 0.2187],
device='cuda:0')
TRT: Scaled Dot Product Attention: tensor([-0.0920, 0.0905, -0.3552, -0.4191, 0.0605, 0.2372, 0.0291, -0.2665,
-0.1009, 0.0400, -0.2203, -0.1254, 0.2096, 0.2612, -0.1429, -0.0454,
-0.0563, -0.2348, -0.0605, 0.0571, 0.1028, -0.1491, 0.2334, -0.0706,
-0.2912, 0.0824, -0.1899, 0.0453, -0.3816, -0.2335, 0.1030, 0.2190],
device='cuda:0')
TRT: MHA Forward: tensor([-0.0924, 0.0904, -0.3557, -0.4190, 0.0607, 0.2370, 0.0293, -0.2666,
-0.1009, 0.0399, -0.2203, -0.1261, 0.2094, 0.2609, -0.1429, -0.0457,
-0.0566, -0.2337, -0.0609, 0.0572, 0.1024, -0.1487, 0.2335, -0.0703,
-0.2909, 0.0832, -0.1907, 0.0462, -0.3819, -0.2341, 0.1027, 0.2187],
device='cuda:0')
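For context, a hedged sketch of the Torch-side comparison the script prints, reusing the sketch module definitions above; the input shape, the weight sharing, and the tolerance are assumptions rather than the actual contents of min_repro.py:

import torch

x = torch.randn(1, 197, 256)                   # assumed input shape
explicit = ExplicitAttention().eval()
sdpa = AttentionUsingScaledDotProduct().eval()
sdpa.load_state_dict(explicit.state_dict())    # share weights so outputs should match

with torch.no_grad():
    y_ref, y_sdpa = explicit(x), sdpa(x)

print("[explicit<->sdpa] Is allclose?", torch.allclose(y_ref, y_sdpa, atol=1e-3))
print("[explicit<->sdpa] Total difference:", (y_ref - y_sdpa).abs().sum())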
Commands or scripts:
Have you tried the latest release?:
Yes
Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):
Yes, running the ONNX model with ONNX Runtime works as expected.
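For completeness, a hedged sketch of the same check done from Python with ONNX Runtime (file name and input shape are assumptions):

import numpy as np
import onnxruntime as ort

x = np.random.randn(1, 197, 256).astype(np.float32)   # assumed input shape
sess = ort.InferenceSession("attention_sdpa.onnx",     # placeholder file name
                            providers=["CPUExecutionProvider"])
(out,) = sess.run(None, {sess.get_inputs()[0].name: x})
print("ONNX Runtime output shape:", out.shape)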

