@@ -302,11 +302,15 @@ def forward(
302302 ), "Packed sequence does not support deterministic mode."
303303
304304 # Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
305- cu_seqlens_q = packed_seq_params .cu_seqlens_q_padded or packed_seq_params .cu_seqlens_q
305+ if packed_seq_params .cu_seqlens_q_padded is not None :
306+ cu_seqlens_q = packed_seq_params .cu_seqlens_q_padded
307+ else :
308+ cu_seqlens_q = packed_seq_params .cu_seqlens_q
306309 # Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
307- cu_seqlens_kv = (
308- packed_seq_params .cu_seqlens_kv_padded or packed_seq_params .cu_seqlens_kv
309- )
310+ if packed_seq_params .cu_seqlens_kv_padded is not None :
311+ cu_seqlens_kv = packed_seq_params .cu_seqlens_kv_padded
312+ else :
313+ cu_seqlens_kv = packed_seq_params .cu_seqlens_kv
310314 assert torch .equal (cu_seqlens_q , cu_seqlens_kv ), (
311315 "Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
312316 f"but got { cu_seqlens_q = } and { cu_seqlens_kv = } "
@@ -624,7 +628,7 @@ def _unpack_sequence(x, cu_seqlens, dim=1):
624628 idx_start = cu_seqlens [i ].item ()
625629 idx_end = cu_seqlens [i + 1 ].item ()
626630 chunked_index = [slice (None )] * dim + [slice (idx_start , idx_end )]
627- unpacked_x .append (x [chunked_index ])
631+ unpacked_x .append (x [tuple ( chunked_index ) ])
628632 return unpacked_x
629633
630634
@@ -879,6 +883,7 @@ def torch_chunk_gated_delta_rule(
879883 initial_state = None ,
880884 output_final_state = False ,
881885 use_qk_l2norm_in_kernel = False ,
886+ cu_seqlens = None ,
882887):
883888 # pylint: disable=line-too-long
884889 '''
@@ -888,6 +893,10 @@ def torch_chunk_gated_delta_rule(
888893 Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
889894 '''
890895
896+ assert cu_seqlens is None , (
897+ "cu_seqlens is not supported for torch_chunk_gated_delta_rule for now."
898+ )
899+
891900 initial_dtype = query .dtype
892901 if use_qk_l2norm_in_kernel :
893902 query = l2norm (query , dim = - 1 , eps = 1e-6 )
0 commit comments