Skip to content

Commit 17baeab

Browse files
committed
fixes
1 parent 9aa7815 commit 17baeab

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

accelerated_scan/warp.cuh

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -356,6 +356,8 @@ __device__ Tuple load_shifted_tuple(const Tuple* ptr, int index, int limit) {
356 356
const int idx = index * Tuple::Size + i + offset;
357 357
if (idx >= 0 && idx < limit * Tuple::Size) {
358 358
x.data[i] = rawPtr[offset];
359 +
} else {
360 +
x.data[i] = 0;
359 361
}
360 362
}
361 363

accelerated_scan/warp.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
7 7

8 8
cpp_source = """
9 9
at::Tensor warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse);
10 +
void warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut);
10 11
"""
11 12

12 13
module = load_inline(

0 commit comments

Comments
 (0)