Skip to content

Commit 83a8607

Browse files
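Fix several bugs: select padded cu_seqlens with explicit `is not None` checks instead of `or`, index with a tuple in _unpack_sequence, and accept (but reject non-None) cu_seqlens in torch_chunk_gated_delta_rule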
Co-authored-by: kunlunl <kunlunl@users.noreply.github.com>
1 parent ae8806c commit 83a8607

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

megatron/core/ssm/gated_delta_net.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,15 @@ def forward(
302302
), "Packed sequence does not support deterministic mode."
303303

304304
# Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
305-
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded or packed_seq_params.cu_seqlens_q
305+
if packed_seq_params.cu_seqlens_q_padded is not None:
306+
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
307+
else:
308+
cu_seqlens_q = packed_seq_params.cu_seqlens_q
306309
# Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
307-
cu_seqlens_kv = (
308-
packed_seq_params.cu_seqlens_kv_padded or packed_seq_params.cu_seqlens_kv
309-
)
310+
if packed_seq_params.cu_seqlens_kv_padded is not None:
311+
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
312+
else:
313+
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
310314
assert torch.equal(cu_seqlens_q, cu_seqlens_kv), (
311315
"Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
312316
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
@@ -624,7 +628,7 @@ def _unpack_sequence(x, cu_seqlens, dim=1):
624628
idx_start = cu_seqlens[i].item()
625629
idx_end = cu_seqlens[i + 1].item()
626630
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
627-
unpacked_x.append(x[chunked_index])
631+
unpacked_x.append(x[tuple(chunked_index)])
628632
return unpacked_x
629633

630634

@@ -879,6 +883,7 @@ def torch_chunk_gated_delta_rule(
879883
initial_state=None,
880884
output_final_state=False,
881885
use_qk_l2norm_in_kernel=False,
886+
cu_seqlens=None,
882887
):
883888
# pylint: disable=line-too-long
884889
'''
@@ -888,6 +893,10 @@ def torch_chunk_gated_delta_rule(
888893
Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
889894
'''
890895

896+
assert cu_seqlens is None, (
897+
"cu_seqlens is not supported for torch_chunk_gated_delta_rule for now."
898+
)
899+
891900
initial_dtype = query.dtype
892901
if use_qk_l2norm_in_kernel:
893902
query = l2norm(query, dim=-1, eps=1e-6)

0 commit comments

Comments
 (0)