
Commit bf3cdb1

support GDN packed sequence
1 parent 2b4b9c4 commit bf3cdb1

File tree (3 files changed: +145 -23 lines)

megatron/core/ssm/gated_delta_net.py
tests/unit_tests/ssm/test_gated_delta_net.py
tests/unit_tests/transformer/test_attention.py

megatron/core/ssm/gated_delta_net.py
Lines changed: 84 additions & 19 deletions

@@ -296,29 +296,71 @@ def forward(
             raise NotImplementedError("GDN does not support inference for now.")

         if packed_seq_params is not None:
-            # TODO: support packed sequence
-            raise NotImplementedError("GDN does not support packed sequence for now.")
+            assert batch == 1, "Packed sequence expects batch dimension to be 1"
+            assert (
+                not self.config.deterministic_mode
+            ), "Packed sequence does not support deterministic mode."
+
+            # Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
+            cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded or packed_seq_params.cu_seqlens_q
+            # Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
+            cu_seqlens_kv = (
+                packed_seq_params.cu_seqlens_kv_padded or packed_seq_params.cu_seqlens_kv
+            )
+            assert torch.equal(cu_seqlens_q, cu_seqlens_kv), (
+                "Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
+                f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
+            )
+            num_packed_seqs = cu_seqlens_q.shape[0] - 1
+            assert num_packed_seqs > 0, (
+                "Number of packed sequences must be greater than 0, "
+                f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
+            )
+        else:
+            cu_seqlens_q = None
+            cu_seqlens_kv = None

         # Input projection
         nvtx_range_push(suffix="in_proj")
         qkvzba, _ = self.in_proj(hidden_states)
         nvtx_range_pop(suffix="in_proj")

         # CP All to All: CP to HP
-        qkvzba = tensor_a2a_cp2hp(
-            qkvzba,
-            seq_dim=0,
-            head_dim=-1,
-            cp_group=self.pg_collection.cp,
-            split_sections=[
-                self.qk_dim_local_tp,
-                self.qk_dim_local_tp,
-                self.v_dim_local_tp,
-                self.v_dim_local_tp,
-                self.num_value_heads // self.tp_size,
-                self.num_value_heads // self.tp_size,
-            ],
-        )
+        if packed_seq_params is not None:
+            unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0)
+            outputs = []
+            for qkvzba_i in unpacked_qkvzba:
+                qkvzba_i = tensor_a2a_cp2hp(
+                    qkvzba_i,
+                    seq_dim=0,
+                    head_dim=-1,
+                    cp_group=self.pg_collection.cp,
+                    split_sections=[
+                        self.qk_dim_local_tp,
+                        self.qk_dim_local_tp,
+                        self.v_dim_local_tp,
+                        self.v_dim_local_tp,
+                        self.num_value_heads // self.tp_size,
+                        self.num_value_heads // self.tp_size,
+                    ],
+                )
+                outputs.append(qkvzba_i)
+            qkvzba = torch.cat(outputs, dim=0)
+        else:
+            qkvzba = tensor_a2a_cp2hp(
+                qkvzba,
+                seq_dim=0,
+                head_dim=-1,
+                cp_group=self.pg_collection.cp,
+                split_sections=[
+                    self.qk_dim_local_tp,
+                    self.qk_dim_local_tp,
+                    self.v_dim_local_tp,
+                    self.v_dim_local_tp,
+                    self.num_value_heads // self.tp_size,
+                    self.num_value_heads // self.tp_size,
+                ],
+            )

         # Transpose: s b x --> b s x
         # From sbhd to bshd format

@@ -385,6 +427,7 @@ def forward(
             activation=self.activation,
             initial_state=None,
             output_final_state=False,
+            cu_seqlens=cu_seqlens_q,
         )
         nvtx_range_pop(suffix="conv1d")

@@ -440,6 +483,7 @@ def forward(
             initial_state=None,
             output_final_state=False,
             use_qk_l2norm_in_kernel=False,
+            cu_seqlens=cu_seqlens_q,
         )
         nvtx_range_pop(suffix="gated_delta_rule")

@@ -454,9 +498,19 @@ def forward(
         norm_out = norm_out.transpose(0, 1).contiguous()

         # CP all to all: HP to CP
-        norm_out = tensor_a2a_hp2cp(
-            norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
-        )
+        if packed_seq_params is not None:
+            unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0)
+            outputs = []
+            for norm_out_i in unpacked_norm_out:
+                norm_out_i = tensor_a2a_hp2cp(
+                    norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
+                )
+                outputs.append(norm_out_i)
+            norm_out = torch.cat(outputs, dim=0)
+        else:
+            norm_out = tensor_a2a_hp2cp(
+                norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
+            )

         # Output projection
         nvtx_range_push(suffix="out_proj")

@@ -575,6 +629,17 @@ def _backward_out_proj(self):
         self.out_proj.backward_dw()


+def _unpack_sequence(x, cu_seqlens, dim=1):
+    unpacked_x = []
+    num_seqs = cu_seqlens.shape[0] - 1
+    for i in range(num_seqs):
+        idx_start = cu_seqlens[i].item()
+        idx_end = cu_seqlens[i + 1].item()
+        chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
+        unpacked_x.append(x[chunked_index])
+    return unpacked_x
+
+
 ####################
 # Sharded state dict utilities
 ####################
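
Note on the new helper: `_unpack_sequence` slices a packed (THD-layout) tensor back into its per-sequence chunks along a chosen dimension, using the cumulative offsets in `cu_seqlens`; the forward pass then runs the CP all-to-all per sequence and re-concatenates the results. A minimal standalone sketch of the same slicing idea (the shapes and values below are illustrative, not taken from the commit):

import torch

def unpack_sequence(x, cu_seqlens, dim=0):
    # Split the packed tensor into one chunk per sequence along `dim`,
    # using consecutive offsets in `cu_seqlens` as slice boundaries.
    chunks = []
    for i in range(cu_seqlens.shape[0] - 1):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        index = [slice(None)] * dim + [slice(start, end)]
        chunks.append(x[tuple(index)])
    return chunks

# Four sequences of length 32 packed along dim 0, with a dummy batch dim of 1.
cu_seqlens = torch.tensor([0, 32, 64, 96, 128])
packed = torch.randn(128, 1, 16)
chunks = unpack_sequence(packed, cu_seqlens, dim=0)
assert [c.shape[0] for c in chunks] == [32, 32, 32, 32]
# Re-concatenating restores the packed layout, mirroring torch.cat(outputs, dim=0) above.
assert torch.equal(torch.cat(chunks, dim=0), packed)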

tests/unit_tests/ssm/test_gated_delta_net.py
Lines changed: 47 additions & 1 deletion

@@ -32,6 +32,7 @@
 )
 from tests.unit_tests.test_utilities import Utils
 from tests.unit_tests.transformer.test_attention import _test_parallel_attention_correctness
+from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params

 try:
     import fla

@@ -138,7 +139,51 @@ def test_gpu_forward(self):
             output.dtype == hidden_states.dtype
         ), f"Output dtype {output.dtype=} mismatch with {hidden_states.dtype=}"

+    def test_gpu_forward_thd_correctness(self):
+        if self.sp_size > 1:
+            pytest.skip("Sequence parallel is not supported for this test case.")

+        atol, rtol = 3e-4, 3e-4
+
+        # Input shape
+        sequence_length = 32
+        micro_batch_size = 4
+        cu_seqlens = [0, 32, 64, 96, 128]
+        # sbhd input shape: [sequence length, batch size, hidden size]
+        sub_sequence_length = sequence_length // self.cp_size
+        hidden_states_sbhd = torch.rand(
+            (sub_sequence_length, micro_batch_size, self.gdn.config.hidden_size)
+        )
+        attention_mask_sbhd = None
+        hidden_states_sbhd = hidden_states_sbhd.cuda().bfloat16()
+        # thd input shape: [sequence length * batch size, 1, hidden size]
+        hidden_states_thd = hidden_states_sbhd.transpose(0, 1).contiguous()
+        hidden_states_thd = hidden_states_thd.view(-1, 1, self.gdn.config.hidden_size)
+        attention_mask_thd = None
+        packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)
+
+        # THD format
+        output_thd, _ = self.gdn(
+            hidden_states_thd, attention_mask_thd, packed_seq_params=packed_seq_params
+        )
+        # SBHD format
+        output_sbhd, _ = self.gdn(hidden_states_sbhd, attention_mask_sbhd)
+        output_sbhd_T = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape)
+
+        rank = torch.distributed.get_rank()
+        assert output_thd.shape[0] == sub_sequence_length * micro_batch_size
+        assert output_thd.shape[1] == 1
+        assert output_thd.shape[2] == self.gdn.config.hidden_size
+        torch.testing.assert_close(
+            output_sbhd_T,
+            output_thd,
+            atol=atol,
+            rtol=rtol,
+            msg=lambda msg: f"Output mismatch ({rank=}): {msg}",
+        )
+
+
+@pytest.mark.parametrize("sequence_packing", [False, True])
 @pytest.mark.parametrize(
     ("tp", "sp", "cp"),
     [

@@ -150,7 +195,7 @@ def test_gpu_forward(self):
     ],
 )
 @pytest.mark.skipif(not HAVE_FLA, reason="FLA is not installed.")
-def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp):
+def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, sequence_packing, tp, sp, cp):
     transformer_config = TransformerConfig(
         hidden_size=128,
         linear_conv_kernel_dim=2,

@@ -191,4 +236,5 @@ def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp):
         seed=123,
         sequence_length=256,
         micro_batch_size=4,
+        sequence_packing=sequence_packing,
     )
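
The new `test_gpu_forward_thd_correctness` compares the packed (THD) path against the regular padded (SBHD) path on the same data. The layout conversion it relies on can be shown in isolation; the sizes here are illustrative and assume no context parallelism (cp_size == 1):

import torch

sequence_length, micro_batch_size, hidden_size = 32, 4, 16

# SBHD-style input: [sequence length, batch size, hidden size]
hidden_states_sbhd = torch.rand(sequence_length, micro_batch_size, hidden_size)

# THD-style input: sequences concatenated along dim 0, batch dimension collapsed to 1.
hidden_states_thd = (
    hidden_states_sbhd.transpose(0, 1).contiguous().view(-1, 1, hidden_size)
)
assert hidden_states_thd.shape == (sequence_length * micro_batch_size, 1, hidden_size)

# Offsets delimiting the packed sequences: [0, 32, 64, 96, 128], as in the test.
cu_seqlens = [i * sequence_length for i in range(micro_batch_size + 1)]
assert cu_seqlens == [0, 32, 64, 96, 128]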

tests/unit_tests/transformer/test_attention.py
Lines changed: 14 additions & 3 deletions

@@ -40,6 +40,7 @@
     init_checkpointing_mock_args,
 )
 from tests.unit_tests.test_utilities import Utils
+from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params

 try:
     from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb

@@ -710,6 +711,7 @@ def _test_parallel_attention_correctness(
     seed=123,
     sequence_length=256,
     micro_batch_size=4,
+    sequence_packing=False,
 ):
     # Model initialization function
     def initialize_gpt_model(

@@ -803,17 +805,24 @@ def initialize_gpt_model(
     def get_tensor_on_this_rank(tensor):
         if cp > 1:
             tensor = get_tensor_on_this_cp_rank(tensor, 0, cp_group)
+        if sequence_packing:
+            tensor = tensor.transpose(0, 1).contiguous().view(-1, 1, *tensor.shape[2:])
         if tp > 1 and sp:
-            sp_seg = sequence_length // tp // cp
+            sp_seg = tensor.shape[0] // tp
             tensor = tensor[tp_rank * sp_seg : (tp_rank + 1) * sp_seg]
         return tensor

     # Calculate parallel model output
+    if sequence_packing:
+        cu_seqlens = [i * sequence_length for i in range(micro_batch_size + 1)]
+        packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)
+    else:
+        packed_seq_params = None
     input_hidden_states = get_tensor_on_this_rank(input_hidden_states)
     input_hidden_states = input_hidden_states.detach().requires_grad_(True)
     parallel_attention = gpt_model[0].decoder.layers[0].self_attention
     output_hidden_states_parallel, bias_hidden_states_parallel = parallel_attention(
-        input_hidden_states, attention_mask=None
+        input_hidden_states, attention_mask=None, packed_seq_params=packed_seq_params
     )
     output_hidden_states_parallel.sum().backward()
     input_grad_parallel = input_hidden_states.grad.detach()

@@ -879,6 +888,7 @@ def get_tensor_on_this_rank(tensor):


 # TODO(yuzhongw): Add test case for fallback_to_eager_attn
+@pytest.mark.parametrize("sequence_packing", [False, True])
 @pytest.mark.parametrize("apply_rope_fusion", [False, True])
 @pytest.mark.parametrize(
     ("tp", "sp", "cp"),

@@ -893,7 +903,7 @@ def get_tensor_on_this_rank(tensor):
 @pytest.mark.parametrize("qk_layernorm", [False, True])
 @pytest.mark.parametrize("output_gate", [False, True])
 def test_parallel_attention_correctness(
-    tmp_path_dist_ckpt, apply_rope_fusion, tp, sp, cp, qk_layernorm, output_gate
+    tmp_path_dist_ckpt, sequence_packing, apply_rope_fusion, tp, sp, cp, qk_layernorm, output_gate
 ):
     transformer_config = TransformerConfig(
         num_layers=1,

@@ -922,6 +932,7 @@ def test_parallel_attention_correctness(
         cp=cp,
         seed=123,
         sequence_length=256,
+        sequence_packing=sequence_packing,
     )
