File tree Expand file tree Collapse file tree 2 files changed +10
-3
lines changed
aten/src/ATen/native/cuda Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -468,7 +468,7 @@ ilpReduce(index_t shift,
468468 if (offset >= shift && offset < size) {
469469 threadVal = r (threadVal, data[offset]);
470470 }
471- size -= blockDim .x ;
471+ size -= blockDim .x > size ? size : blockDim . x ;
472472 data += blockDim .x ;
473473 }
474474 index_t last = size % (ILP * blockDim .x );
@@ -518,7 +518,7 @@ WriteFpropResultsVectorized(
518518 if (offset >= shift && offset < size) {
519519 output[offset] = epilogue (input[offset]);
520520 }
521- size -= blockDim .x ;
521+ size -= blockDim .x > size ? size : blockDim . x ;
522522 input += blockDim .x ;
523523 output += blockDim .x ;
524524 }
@@ -573,7 +573,7 @@ WriteBpropResultsVectorized(
573573 if (threadIdx .x >= shift) {
574574 gradInput[offset] = epilogue (gradOutput[offset], output[offset]);
575575 }
576- size -= blockDim .x ;
576+ size -= blockDim .x > size ? size : blockDim . x ;
577577 gradInput += blockDim .x ;
578578 output += blockDim .x ;
579579 gradOutput += blockDim .x ;
Original file line number Diff line number Diff line change @@ -10498,6 +10498,13 @@ def run_test(*shape):
1049810498 run_test(1100000000, 2) # Illegal memory access https://github.com/pytorch/pytorch/issues/52715
1049910499 run_test(2200000000, 1) # invalid configuration argument https://github.com/pytorch/pytorch/issues/52716
1050010500
10501+ @onlyCUDA
10502+ @dtypes(torch.double)
10503+ def test_softmax_double(self, device, dtype):
10504+ logits = torch.randn(5, 513, dtype=dtype, device=device)
10505+ expected_ones = F.log_softmax(logits, dim=1).exp().sum(dim=1)
10506+ self.assertEqual(expected_ones, torch.ones_like(expected_ones))
10507+
1050110508 @onlyCUDA
1050210509 @dtypes(torch.half)
1050310510 @largeTensorTest("20GB")
You can’t perform that action at this time.
0 commit comments