Skip to content

Commit e94395d

Browse files
Fix several bugs
Signed-off-by: yuzhongw <yuzhongw@nvidia.com>
Co-authored-by: kunlunl <kunlunl@nvidia.com>
1 parent bf3cdb1 commit e94395d

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

megatron/core/ssm/gated_delta_net.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -302,11 +302,15 @@ def forward(
302302
), "Packed sequence does not support deterministic mode."
303303

304304
# Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
305-
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded or packed_seq_params.cu_seqlens_q
305+
if packed_seq_params.cu_seqlens_q_padded is not None:
306+
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
307+
else:
308+
cu_seqlens_q = packed_seq_params.cu_seqlens_q
306309
# Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
307-
cu_seqlens_kv = (
308-
packed_seq_params.cu_seqlens_kv_padded or packed_seq_params.cu_seqlens_kv
309-
)
310+
if packed_seq_params.cu_seqlens_kv_padded is not None:
311+
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
312+
else:
313+
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
310314
assert torch.equal(cu_seqlens_q, cu_seqlens_kv), (
311315
"Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
312316
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
@@ -636,7 +640,7 @@ def _unpack_sequence(x, cu_seqlens, dim=1):
636640
idx_start = cu_seqlens[i].item()
637641
idx_end = cu_seqlens[i + 1].item()
638642
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
639-
unpacked_x.append(x[chunked_index])
643+
unpacked_x.append(x[tuple(chunked_index)])
640644
return unpacked_x
641645

642646

@@ -891,6 +895,7 @@ def torch_chunk_gated_delta_rule(
891895
initial_state=None,
892896
output_final_state=False,
893897
use_qk_l2norm_in_kernel=False,
898+
cu_seqlens=None,
894899
):
895900
# pylint: disable=line-too-long
896901
'''
@@ -900,6 +905,10 @@ def torch_chunk_gated_delta_rule(
900905
Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
901906
'''
902907

908+
assert cu_seqlens is None, (
909+
"cu_seqlens is not supported for torch_chunk_gated_delta_rule for now."
910+
)
911+
903912
initial_dtype = query.dtype
904913
if use_qk_l2norm_in_kernel:
905914
query = l2norm(query, dim=-1, eps=1e-6)

0 commit comments

Comments
 (0)