@@ -72,11 +72,11 @@ def causal_conv1d_fwd_kernel(
7272
7373 if IS_VARLEN :
7474 i_n , i_t = tl .load (chunk_indices + i_t * 2 ).to (tl .int32 ), tl .load (chunk_indices + i_t * 2 + 1 ).to (tl .int32 )
75- bos , eos = tl .load (cu_seqlens + i_n ), tl .load (cu_seqlens + i_n + 1 )
75+ bos , eos = tl .load (cu_seqlens + i_n ). to ( tl . int64 ) , tl .load (cu_seqlens + i_n + 1 ). to ( tl . int64 )
7676 T = eos - bos
7777 else :
7878 i_n = i_b
79- bos , eos = i_b * T , i_b * T + T
79+ bos , eos = ( i_b * T ). to ( tl . int64 ), ( i_b * T + T ). to ( tl . int64 )
8080
8181 o_d = i_d * BD + tl .arange (0 , BD )
8282 o_w = tl .arange (0 , BW ) + W - BW
@@ -184,12 +184,12 @@ def causal_conv1d_bwd_kernel(
184184 if IS_VARLEN :
185185 i_tg = i_t
186186 i_n , i_t = tl .load (chunk_indices + i_t * 2 ).to (tl .int32 ), tl .load (chunk_indices + i_t * 2 + 1 ).to (tl .int32 )
187- bos , eos = tl .load (cu_seqlens + i_n ), tl .load (cu_seqlens + i_n + 1 )
187+ bos , eos = tl .load (cu_seqlens + i_n ). to ( tl . int64 ) , tl .load (cu_seqlens + i_n + 1 ). to ( tl . int64 )
188188 T = eos - bos
189189 else :
190190 i_tg = i_b * tl .num_programs (1 ) + i_t
191191 i_n = i_b
192- bos , eos = i_b * T , i_b * T + T
192+ bos , eos = ( i_b * T ). to ( tl . int64 ), ( i_b * T + T ). to ( tl . int64 )
193193
194194 o_d = i_d * BD + tl .arange (0 , BD )
195195 o_w = tl .arange (0 , BW ) + W - BW
@@ -544,10 +544,10 @@ def causal_conv1d_states_fwd_kernel(
544544):
545545 i_d , i_n = tl .program_id (0 ), tl .program_id (1 )
546546 if IS_VARLEN :
547- bos , eos = tl .load (cu_seqlens + i_n ), tl .load (cu_seqlens + i_n + 1 )
547+ bos , eos = tl .load (cu_seqlens + i_n ). to ( tl . int64 ) , tl .load (cu_seqlens + i_n + 1 ). to ( tl . int64 )
548548 T = eos - bos
549549 else :
550- bos , eos = i_n * T , i_n * T + T
550+ bos , eos = ( i_n * T ). to ( tl . int64 ), ( i_n * T + T ). to ( tl . int64 )
551551
552552 o_t = eos - BW + tl .arange (0 , BW )
553553 o_d = i_d * BD + tl .arange (0 , BD )
0 commit comments