Skip to content

Commit 73d512d

Browse files
support GDN packed sequence
1 parent 6e2153b commit 73d512d

File tree

2 files changed

+222
-71
lines changed

2 files changed

+222
-71
lines changed

megatron/core/ssm/gated_delta_net.py

Lines changed: 159 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -293,29 +293,64 @@ def forward(
293293
raise NotImplementedError("GDN does not support inference for now.")
294294

295295
if packed_seq_params is not None:
296-
# TODO: support packed sequence
297-
raise NotImplementedError("GDN does not support packed sequence for now.")
296+
assert batch == 1, "Packed sequence expects batch dimension to be 1"
297+
# Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
298+
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded or packed_seq_params.cu_seqlens_q
299+
# Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
300+
cu_seqlens_kv = (
301+
packed_seq_params.cu_seqlens_kv_padded or packed_seq_params.cu_seqlens_kv
302+
)
303+
assert torch.equal(cu_seqlens_q, cu_seqlens_kv), (
304+
"Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
305+
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
306+
)
307+
num_packed_seqs = cu_seqlens_q.shape[0] - 1
308+
assert num_packed_seqs > 0, (
309+
"Number of packed sequences must be greater than 0, "
310+
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
311+
)
298312

299313
# Input projection
300314
nvtx_range_push(suffix="in_proj")
301315
qkvzba, _ = self.in_proj(hidden_states)
302316
nvtx_range_pop(suffix="in_proj")
303317

304318
# CP All to All: CP to HP
305-
qkvzba = tensor_a2a_cp2hp(
306-
qkvzba,
307-
seq_dim=0,
308-
head_dim=-1,
309-
cp_group=self.pg_collection.cp,
310-
split_sections=[
311-
self.qk_dim_local_tp,
312-
self.qk_dim_local_tp,
313-
self.v_dim_local_tp,
314-
self.v_dim_local_tp,
315-
self.num_value_heads // self.tp_size,
316-
self.num_value_heads // self.tp_size,
317-
],
318-
)
319+
if packed_seq_params is not None:
320+
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0)
321+
outputs = []
322+
for qkvzba_i in unpacked_qkvzba:
323+
qkvzba_i = tensor_a2a_cp2hp(
324+
qkvzba_i,
325+
seq_dim=0,
326+
head_dim=-1,
327+
cp_group=self.pg_collection.cp,
328+
split_sections=[
329+
self.qk_dim_local_tp,
330+
self.qk_dim_local_tp,
331+
self.v_dim_local_tp,
332+
self.v_dim_local_tp,
333+
self.num_value_heads // self.tp_size,
334+
self.num_value_heads // self.tp_size,
335+
],
336+
)
337+
outputs.append(qkvzba_i)
338+
qkvzba = torch.cat(outputs, dim=0)
339+
else:
340+
qkvzba = tensor_a2a_cp2hp(
341+
qkvzba,
342+
seq_dim=0,
343+
head_dim=-1,
344+
cp_group=self.pg_collection.cp,
345+
split_sections=[
346+
self.qk_dim_local_tp,
347+
self.qk_dim_local_tp,
348+
self.v_dim_local_tp,
349+
self.v_dim_local_tp,
350+
self.num_value_heads // self.tp_size,
351+
self.num_value_heads // self.tp_size,
352+
],
353+
)
319354

320355
# Transpose: s b x --> b s x
321356
# From sbhd to bshd format
@@ -337,51 +372,18 @@ def forward(
337372
alpha = alpha.reshape(batch, seq_len, -1)
338373

339374
# Convolution on qkv
340-
qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s
341375
nvtx_range_push(suffix="conv1d")
342-
qkv_channels_split_sections = [
343-
self.qk_dim_local_tp,
344-
self.qk_dim_local_tp,
345-
self.v_dim_local_tp,
346-
]
347-
conv1d_weight = get_parameter_local_cp(
348-
self.conv1d.weight,
349-
dim=0,
350-
cp_group=self.pg_collection.cp,
351-
split_sections=qkv_channels_split_sections,
352-
)
353-
conv1d_bias = (
354-
get_parameter_local_cp(
355-
self.conv1d.bias,
356-
dim=0,
357-
cp_group=self.pg_collection.cp,
358-
split_sections=qkv_channels_split_sections,
359-
)
360-
if self.conv_bias
361-
else None
362-
)
363-
if (causal_conv1d_fn is None) or self.config.deterministic_mode:
364-
conv_out = F.conv1d(
365-
input=qkv,
366-
weight=conv1d_weight,
367-
bias=conv1d_bias,
368-
stride=self.conv1d.stride,
369-
padding=self.conv1d.padding,
370-
dilation=self.conv1d.dilation,
371-
groups=self.conv_dim_local_tp // self.cp_size,
372-
)
373-
qkv = self.act_fn(conv_out[..., :seq_len])
376+
if packed_seq_params is not None:
377+
unpacked_qkv = _unpack_sequence(qkv, cu_seqlens_q)
378+
outputs = []
379+
for qkv_i in unpacked_qkv:
380+
qkv_i = self._conv1d_on_qkv(qkv_i)
381+
outputs.append(qkv_i)
382+
qkv = torch.cat(outputs, dim=1)
374383
else:
375-
assert self.activation in ["silu", "swish"]
376-
qkv = causal_conv1d_fn(
377-
x=qkv,
378-
weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w
379-
bias=conv1d_bias,
380-
activation=self.activation,
381-
)
384+
qkv = self._conv1d_on_qkv(qkv)
382385
nvtx_range_pop(suffix="conv1d")
383386
# Split qkv into query, key, and value
384-
qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d
385387
query, key, value = torch.split(
386388
qkv,
387389
[
@@ -422,18 +424,36 @@ def forward(
422424

423425
nvtx_range_push(suffix="gated_delta_rule")
424426
if self.config.deterministic_mode:
425-
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
426-
query,
427-
key,
428-
value,
429-
g=g,
430-
beta=beta,
431-
initial_state=None,
432-
output_final_state=False,
433-
use_qk_l2norm_in_kernel=False,
434-
)
427+
gated_delta_rule_fn = torch_chunk_gated_delta_rule
428+
else:
429+
gated_delta_rule_fn = chunk_gated_delta_rule
430+
431+
if packed_seq_params is not None:
432+
# Packed sequence forward pass (THD format)
433+
query = _unpack_sequence(query, cu_seqlens_q)
434+
key = _unpack_sequence(key, cu_seqlens_kv)
435+
value = _unpack_sequence(value, cu_seqlens_kv)
436+
g = _unpack_sequence(g, cu_seqlens_q)
437+
beta = _unpack_sequence(beta, cu_seqlens_q)
438+
439+
outputs = []
440+
for i, (q_i, k_i, v_i, g_i, beta_i) in enumerate(zip(query, key, value, g, beta)):
441+
out_i, last_recurrent_state = gated_delta_rule_fn(
442+
q_i,
443+
k_i,
444+
v_i,
445+
g=g_i,
446+
beta=beta_i,
447+
initial_state=None,
448+
output_final_state=False,
449+
use_qk_l2norm_in_kernel=False,
450+
)
451+
outputs.append(out_i)
452+
453+
core_attn_out = torch.cat(outputs, dim=1)
435454
else:
436-
core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
455+
# Regular forward pass (BSHD format)
456+
core_attn_out, last_recurrent_state = gated_delta_rule_fn(
437457
query,
438458
key,
439459
value,
@@ -456,9 +476,19 @@ def forward(
456476
norm_out = norm_out.transpose(0, 1).contiguous()
457477

458478
# CP all to all: HP to CP
459-
norm_out = tensor_a2a_hp2cp(
460-
norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
461-
)
479+
if packed_seq_params is not None:
480+
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0)
481+
outputs = []
482+
for norm_out_i in unpacked_norm_out:
483+
norm_out_i = tensor_a2a_hp2cp(
484+
norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
485+
)
486+
outputs.append(norm_out_i)
487+
norm_out = torch.cat(outputs, dim=0)
488+
else:
489+
norm_out = tensor_a2a_hp2cp(
490+
norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
491+
)
462492

463493
# Output projection
464494
nvtx_range_push(suffix="out_proj")
@@ -467,6 +497,53 @@ def forward(
467497

468498
return out, out_bias
469499

500+
def _conv1d_on_qkv(self, qkv):
    """Run the depthwise causal conv1d (plus activation) over packed qkv channels.

    Args:
        qkv: [batch, seq, channels] tensor holding the concatenated q/k/v
            channels local to this TP rank.

    Returns:
        Tensor in the same [batch, seq, channels] layout after convolution
        and activation.
    """
    # conv1d expects channel-first layout: b, s, d -> b, d, s
    qkv = qkv.transpose(1, 2).contiguous()
    seq_len = qkv.shape[2]

    # Channel sections (q, k, v) used to slice out this CP rank's share of
    # the convolution parameters.
    qkv_channels_split_sections = [
        self.qk_dim_local_tp,
        self.qk_dim_local_tp,
        self.v_dim_local_tp,
    ]
    conv1d_weight = get_parameter_local_cp(
        self.conv1d.weight,
        dim=0,
        cp_group=self.pg_collection.cp,
        split_sections=qkv_channels_split_sections,
    )
    if self.conv_bias:
        conv1d_bias = get_parameter_local_cp(
            self.conv1d.bias,
            dim=0,
            cp_group=self.pg_collection.cp,
            split_sections=qkv_channels_split_sections,
        )
    else:
        conv1d_bias = None

    if (causal_conv1d_fn is None) or self.config.deterministic_mode:
        # Reference / deterministic path: plain grouped conv followed by the
        # activation, truncated back to the original sequence length.
        conv_out = F.conv1d(
            input=qkv,
            weight=conv1d_weight,
            bias=conv1d_bias,
            stride=self.conv1d.stride,
            padding=self.conv1d.padding,
            dilation=self.conv1d.dilation,
            groups=self.conv_dim_local_tp // self.cp_size,
        )
        qkv = self.act_fn(conv_out[..., :seq_len])
    else:
        # Fused CUDA kernel applies the activation internally.
        assert self.activation in ["silu", "swish"]
        qkv = causal_conv1d_fn(
            x=qkv,
            weight=conv1d_weight.squeeze(1),  # d, 1, w -> d, w
            bias=conv1d_bias,
            activation=self.activation,
        )

    # Back to sequence-first layout: b, d, s -> b, s, d
    return qkv.transpose(1, 2)
546+
470547
@jit_fuser
471548
def _apply_gated_norm(self, x, gate):
472549
# Output Norm
@@ -564,6 +641,17 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_gr
564641
return sharded_state_dict
565642

566643

644+
def _unpack_sequence(x, cu_seqlens, dim=1):
645+
unpacked_x = []
646+
num_seqs = cu_seqlens.shape[0] - 1
647+
for i in range(num_seqs):
648+
idx_start = cu_seqlens[i].item()
649+
idx_end = cu_seqlens[i + 1].item()
650+
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
651+
unpacked_x.append(x[chunked_index])
652+
return unpacked_x
653+
654+
567655
####################
568656
# Sharded state dict utilities
569657
####################

tests/unit_tests/ssm/test_gated_delta_net.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from tests.unit_tests.test_utilities import Utils
3434
from tests.unit_tests.transformer.test_attention import _test_parallel_attention_correctness
35+
from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params
3536

3637
try:
3738
import fla
@@ -138,6 +139,68 @@ def test_gpu_forward(self):
138139
output.dtype == hidden_states.dtype
139140
), f"Output dtype {output.dtype=} mismatch with {hidden_states.dtype=}"
140141

142+
def test_gpu_forward_thd(self):
    """GDN forward in packed (THD) format produces the expected output shape."""
    sequence_length = 32
    micro_batch_size = 4
    cu_seqlens = [0, 32, 64, 96, 128]
    hidden_size = self.gdn.config.hidden_size

    # Build an sbhd activation for this CP/SP rank:
    # [sub sequence length, batch size, hidden size].
    sub_sequence_length = sequence_length // self.cp_size // self.sp_size
    hidden_states_sbhd = (
        torch.rand((sub_sequence_length, micro_batch_size, hidden_size))
        .cuda()
        .bfloat16()
    )
    # Re-pack as thd: [sequence length * batch size, 1, hidden size].
    hidden_states_thd = (
        hidden_states_sbhd.transpose(0, 1).contiguous().view(-1, 1, hidden_size)
    )
    attention_mask = None
    packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)

    output, _ = self.gdn(hidden_states_thd, attention_mask, packed_seq_params=packed_seq_params)

    assert output.shape[0] == sub_sequence_length * micro_batch_size
    assert output.shape[1] == 1
    assert output.shape[2] == hidden_size
165+
def test_gpu_forward_thd_correctness(self):
    """Packed (THD) forward must numerically match the regular SBHD forward."""
    if self.sp_size > 1:
        pytest.skip("Sequence parallel is not supported for this test case.")

    atol, rtol = 3e-4, 3e-4

    sequence_length = 32
    micro_batch_size = 4
    cu_seqlens = [0, 32, 64, 96, 128]
    hidden_size = self.gdn.config.hidden_size

    # sbhd activation for this CP rank: [sub seq length, batch, hidden size].
    sub_sequence_length = sequence_length // self.cp_size
    hidden_states_sbhd = (
        torch.rand((sub_sequence_length, micro_batch_size, hidden_size))
        .cuda()
        .bfloat16()
    )
    attention_mask_sbhd = None
    # Same data re-packed as thd: [seq length * batch size, 1, hidden size].
    hidden_states_thd = (
        hidden_states_sbhd.transpose(0, 1).contiguous().view(-1, 1, hidden_size)
    )
    attention_mask_thd = None
    packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)

    # Reference: regular SBHD forward.
    output_sbhd, _ = self.gdn(hidden_states_sbhd, attention_mask_sbhd)
    # Under test: packed THD forward.
    output_thd, _ = self.gdn(
        hidden_states_thd, attention_mask_thd, packed_seq_params=packed_seq_params
    )

    # Bring the reference output into the THD layout before comparing.
    reference_thd = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape)
    rank = torch.distributed.get_rank()
    torch.testing.assert_close(
        reference_thd,
        output_thd,
        atol=atol,
        rtol=rtol,
        msg=lambda msg: f"Output mismatch ({rank=}): {msg}",
    )
203+
141204

142205
@pytest.mark.parametrize(
143206
("tp", "sp", "cp"),

0 commit comments

Comments
 (0)