Skip to content

Commit f321568

Browse files
authored
[Conv] Fix potential OOB problems (#615)
1 parent 7b9ec6e commit f321568

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

fla/modules/convolution.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ def causal_conv1d_fwd_kernel(
7272

7373
if IS_VARLEN:
7474
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
75-
bos, eos = tl.load(cu_seqlens + i_n), tl.load(cu_seqlens + i_n + 1)
75+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
7676
T = eos - bos
7777
else:
7878
i_n = i_b
79-
bos, eos = i_b * T, i_b * T + T
79+
bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64)
8080

8181
o_d = i_d * BD + tl.arange(0, BD)
8282
o_w = tl.arange(0, BW) + W - BW
@@ -184,12 +184,12 @@ def causal_conv1d_bwd_kernel(
184184
if IS_VARLEN:
185185
i_tg = i_t
186186
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
187-
bos, eos = tl.load(cu_seqlens + i_n), tl.load(cu_seqlens + i_n + 1)
187+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
188188
T = eos - bos
189189
else:
190190
i_tg = i_b * tl.num_programs(1) + i_t
191191
i_n = i_b
192-
bos, eos = i_b * T, i_b * T + T
192+
bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64)
193193

194194
o_d = i_d * BD + tl.arange(0, BD)
195195
o_w = tl.arange(0, BW) + W - BW
@@ -544,10 +544,10 @@ def causal_conv1d_states_fwd_kernel(
544544
):
545545
i_d, i_n = tl.program_id(0), tl.program_id(1)
546546
if IS_VARLEN:
547-
bos, eos = tl.load(cu_seqlens + i_n), tl.load(cu_seqlens + i_n + 1)
547+
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
548548
T = eos - bos
549549
else:
550-
bos, eos = i_n * T, i_n * T + T
550+
bos, eos = (i_n * T).to(tl.int64), (i_n * T + T).to(tl.int64)
551551

552552
o_t = eos - BW + tl.arange(0, BW)
553553
o_d = i_d * BD + tl.arange(0, BD)

0 commit comments

Comments
 (0)