
Commit ce580d3

xinyazhang and eqy authored
[Release/2.6] Backport softmax fixes from 2.8dev (#2247)
This fixes an out-of-bounds (OOB) memory access for the following code:

```python
import torch

qk = torch.randn((9, 1017), dtype=torch.float64, device='cuda')
smqk = torch.softmax(qk, dim=-1)
```

Upstream PRs:

* pytorch#144009
* pytorch#154778

Co-authored-by: eqy <[email protected]>
1 parent a854563 commit ce580d3

File tree

2 files changed: +10 -3 lines


aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 3 additions & 3 deletions

```diff
@@ -468,7 +468,7 @@ ilpReduce(index_t shift,
     if (offset >= shift && offset < size) {
       threadVal = r(threadVal, data[offset]);
     }
-    size -= blockDim.x;
+    size -= blockDim.x > size ? size : blockDim.x;
     data += blockDim.x;
   }
   index_t last = size % (ILP * blockDim.x);
@@ -518,7 +518,7 @@ WriteFpropResultsVectorized(
     if (offset >= shift && offset < size) {
       output[offset] = epilogue(input[offset]);
     }
-    size -= blockDim.x;
+    size -= blockDim.x > size ? size : blockDim.x;
     input += blockDim.x;
     output += blockDim.x;
   }
@@ -573,7 +573,7 @@ WriteBpropResultsVectorized(
     if (threadIdx.x >= shift) {
       gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
     }
-    size -= blockDim.x;
+    size -= blockDim.x > size ? size : blockDim.x;
     gradInput += blockDim.x;
     output += blockDim.x;
     gradOutput += blockDim.x;
```
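All three hunks make the same one-line change in `ilpReduce`, `WriteFpropResultsVectorized`, and `WriteBpropResultsVectorized`: the remaining-element counter `size` is now decremented by at most its current value, so it bottoms out at zero instead of running past it once fewer than `blockDim.x` elements are left in the row. A minimal host-side sketch of the arithmetic (the types and values below are illustrative assumptions, not the kernels' actual definitions):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Stand-ins for the kernels' loop bookkeeping: `size` is the number of
  // elements still to process in the current row, `block` plays the role
  // of blockDim.x.
  int64_t size = 1017;          // row length from the reproducer in the commit message
  const uint32_t block = 1024;  // assumed block width, wider than the row

  int64_t unclamped = size - block;                          // old code: -7
  int64_t clamped   = size - (block > size ? size : block);  // new code:  0

  printf("unclamped: %lld, clamped: %lld\n",
         (long long)unclamped, (long long)clamped);
  return 0;
}
```

With the unclamped form, `size` can end up below zero (or wrap around, if the index type is unsigned), so later bookkeeping such as `index_t last = size % (ILP * blockDim.x)` no longer describes the actual row, which is presumably how the float64 softmax in the reproducer ended up accessing memory out of bounds.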

test/test_nn.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -10498,6 +10498,13 @@ def run_test(*shape):
         run_test(1100000000, 2) # Illegal memory access https://github.com/pytorch/pytorch/issues/52715
         run_test(2200000000, 1) # invalid configuration argument https://github.com/pytorch/pytorch/issues/52716
 
+    @onlyCUDA
+    @dtypes(torch.double)
+    def test_softmax_double(self, device, dtype):
+        logits = torch.randn(5, 513, dtype=dtype, device=device)
+        expected_ones = F.log_softmax(logits, dim=1).exp().sum(dim=1)
+        self.assertEqual(expected_ones, torch.ones_like(expected_ones))
+
     @onlyCUDA
     @dtypes(torch.half)
     @largeTensorTest("20GB")
```
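The added `test_softmax_double` case covers the fix from Python: for a double-precision CUDA input it checks that `exp(log_softmax(logits))` sums to one along the softmax dimension, a property the out-of-bounds access described in the commit message could corrupt before the kernel change.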
