@@ -302,11 +302,15 @@ def forward(
302302 ), "Packed sequence does not support deterministic mode."
303303
304304 # Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
305- cu_seqlens_q = packed_seq_params .cu_seqlens_q_padded or packed_seq_params .cu_seqlens_q
305+ if packed_seq_params .cu_seqlens_q_padded is not None :
306+ cu_seqlens_q = packed_seq_params .cu_seqlens_q_padded
307+ else :
308+ cu_seqlens_q = packed_seq_params .cu_seqlens_q
306309 # Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
307- cu_seqlens_kv = (
308- packed_seq_params .cu_seqlens_kv_padded or packed_seq_params .cu_seqlens_kv
309- )
310+ if packed_seq_params .cu_seqlens_kv_padded is not None :
311+ cu_seqlens_kv = packed_seq_params .cu_seqlens_kv_padded
312+ else :
313+ cu_seqlens_kv = packed_seq_params .cu_seqlens_kv
310314 assert torch .equal (cu_seqlens_q , cu_seqlens_kv ), (
311315 "Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
312316 f"but got { cu_seqlens_q = } and { cu_seqlens_kv = } "
@@ -636,7 +640,7 @@ def _unpack_sequence(x, cu_seqlens, dim=1):
636640 idx_start = cu_seqlens [i ].item ()
637641 idx_end = cu_seqlens [i + 1 ].item ()
638642 chunked_index = [slice (None )] * dim + [slice (idx_start , idx_end )]
639- unpacked_x .append (x [chunked_index ])
643+ unpacked_x .append (x [tuple ( chunked_index ) ])
640644 return unpacked_x
641645
642646
@@ -891,6 +895,7 @@ def torch_chunk_gated_delta_rule(
891895 initial_state = None ,
892896 output_final_state = False ,
893897 use_qk_l2norm_in_kernel = False ,
898+ cu_seqlens = None ,
894899):
895900 # pylint: disable=line-too-long
896901 '''
@@ -900,6 +905,10 @@ def torch_chunk_gated_delta_rule(
900905 Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
901906 '''
902907
908+ assert cu_seqlens is None , (
909+ "cu_seqlens is not supported for torch_chunk_gated_delta_rule for now."
910+ )
911+
903912 initial_dtype = query .dtype
904913 if use_qk_l2norm_in_kernel :
905914 query = l2norm (query , dim = - 1 , eps = 1e-6 )
0 commit comments