Commit 194539e

jhavukainen authored and pytorchmergebot committed
Address NaNs if SDPA is called with all values masked from query (pytorch#157727)
Fixes pytorch#156707. Detect whether all values along the softmax axis are -inf and overwrite the outputs of those computations with zeros before the final matmul, aligning the behavior with the CPU implementation. According to the original issue, these cases, where every value along a dimension of the attention mask is false and the softmax output is therefore undefined, occur with left-padded batches during generation in HF transformers.

Pull Request resolved: pytorch#157727
Approved by: https://github.com/malfet
1 parent bcf5063 commit 194539e

File tree: 2 files changed, +27 −1 lines
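For context, here is a minimal sketch of the failure mode (illustrative only, not part of the change; the tensor names and shapes are arbitrary): when every mask entry for a query row is False, every score in that row becomes -inf, so softmax evaluates 0/0 = NaN for the whole row.

import torch

scores = torch.randn(1, 1, 2, 2)                          # [batch, heads, L, S] attention scores
mask = torch.tensor([[[[False, False], [True, True]]]])   # query row 0 is fully masked
masked = scores.masked_fill(~mask, float("-inf"))         # masked positions become -inf
print(masked.softmax(dim=-1))                             # row 0 comes out as [nan, nan]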

aten/src/ATen/native/mps/operations/Attention.mm

Lines changed: 15 additions & 1 deletion
@@ -114,8 +114,22 @@
     graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
     maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
   }
+
+  // Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
+  // Overwrites expected NANs in sm with zeros.
+  auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType];
+  auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil];
+  auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil];
+  auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil];
+  auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType];
+
   auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];
-  auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil];
+  MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask
+                                               truePredicateTensor:zeroTensor
+                                              falsePredicateTensor:sm
+                                                              name:nil];
+
+  auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil];
   graph->qTensor = qTensor;
   graph->kTensor = kTensor;
   graph->vTensor = vTensor;
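Read as plain tensor operations, the inserted graph nodes amount to roughly the following (a hedged Python equivalent for illustration, not part of the change; the names mirror the Objective-C++ above, and dim 3 is assumed to be the key/sequence axis of [batch, heads, L, S] scores):

import torch

def corrected_attention(maskedMM: torch.Tensor, vTensor: torch.Tensor) -> torch.Tensor:
    elem_neg_inf = torch.isneginf(maskedMM)                          # equalWithPrimaryTensor vs. the -inf constant
    all_neg_infs_along_axis = elem_neg_inf.all(dim=3, keepdim=True)  # reductionAndWithTensor axis:3
    zero_mask = all_neg_infs_along_axis.expand_as(maskedMM)          # broadcastTensor toShape:maskedMM.shape
    sm = maskedMM.softmax(dim=3)                                     # softMaxWithTensor axis:3
    correctedSM = torch.where(zero_mask, torch.zeros_like(sm), sm)   # selectWithPredicateTensor: zero fully masked rows
    return correctedSM @ vTensor                                     # matrixMultiplicationWithPrimaryTensor

Note that the detection runs on the pre-softmax scores (maskedMM), where fully masked rows are exactly -inf everywhere; the select then replaces the NaN rows that softmax produces for them with zeros.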

test/test_mps.py

Lines changed: 12 additions & 0 deletions
@@ -9257,6 +9257,18 @@ def test_sdpa_mask_fp16_L6(self):
     def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
         self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
 
+    # Regression test from: https://github.com/pytorch/pytorch/issues/156707
+    @parametrize("dtype", [torch.float16, torch.float32])
+    def test_sdpa_full_mask(self, dtype):
+        q = torch.randn(1, 1, 2, 4, dtype=dtype)
+        k = torch.randn(1, 1, 2, 4, dtype=dtype)
+        v = torch.randn(1, 1, 2, 4, dtype=dtype)
+        mask = torch.tensor([[[[False, False], [True, True]]]], dtype=torch.bool)
+
+        out_cpu = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+        out_mps = F.scaled_dot_product_attention(q.to('mps'), k.to('mps'), v.to('mps'), attn_mask=mask.to('mps'))
+        self._compare_tensors(out_mps.cpu(), out_cpu)
+
     @parametrize("dtype", [torch.float16, torch.float32])
     def test_sdpa_3d_input(self, dtype):
         head_num, seq_len, embed_dim = 16, 16, 80
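To exercise just this regression test locally on a machine where the MPS backend is available, something like pytest test/test_mps.py -k test_sdpa_full_mask should select both parametrized dtype variants.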
