Skip to content

Commit 2575c6d

Browse files
support GDN packed sequence
1 parent 02ea26d commit 2575c6d

File tree

2 files changed

+216
-65
lines changed

2 files changed

+216
-65
lines changed

megatron/core/ssm/gated_delta_net.py

Lines changed: 152 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,11 @@ def __init__(
9999
cp_comm_type: No use for GDN, just for compatibility with Attention class.
100100
"""
101101

102-
if not HAVE_FLA:
103-
raise ImportError("FLA is not installed. Please install it with `pip install fla`.")
102+
if not HAVE_FLA and not self.config.deterministic_mode:
103+
raise ImportError(
104+
"FLA is not installed. Please install it with "
105+
"`pip install fla` or use deterministic mode."
106+
)
104107

105108
super().__init__(config)
106109

@@ -304,28 +307,62 @@ def forward(
304307
raise NotImplementedError("GDN does not support inference for now.")
305308

306309
if packed_seq_params is not None:
307-
# TODO: support packed sequence
308-
raise NotImplementedError("GDN does not support packed sequence for now.")
310+
assert batch == 1, "Packed sequence expects batch dimension to be 1"
311+
# Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
312+
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded or packed_seq_params.cu_seqlens_q
313+
# Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
314+
cu_seqlens_kv = (
315+
packed_seq_params.cu_seqlens_kv_padded or packed_seq_params.cu_seqlens_kv
316+
)
317+
assert torch.equal(cu_seqlens_q, cu_seqlens_kv), (
318+
"Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
319+
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
320+
)
321+
num_packed_seqs = cu_seqlens_q.shape[0] - 1
322+
assert num_packed_seqs > 0, (
323+
"Number of packed sequences must be greater than 0, "
324+
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
325+
)
309326

310327
# Input projection
311328
nvtx_range_push(suffix="in_proj")
312329
qkvzba, _ = self.in_proj(hidden_states)
313330
nvtx_range_pop(suffix="in_proj")
314331

315332
# CP All to All: CP to HP
316-
qkvzba = self.cp.tensor_a2a_cp2hp(
317-
qkvzba,
318-
seq_dim=0,
319-
head_dim=-1,
320-
split_size_or_sections=[
321-
self.qk_dim_local_tp,
322-
self.qk_dim_local_tp,
323-
self.v_dim_local_tp,
324-
self.v_dim_local_tp,
325-
self.num_value_heads // self.tp_size,
326-
self.num_value_heads // self.tp_size,
327-
],
328-
)
333+
if packed_seq_params is not None:
334+
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0)
335+
outputs = []
336+
for qkvzba_i in unpacked_qkvzba:
337+
qkvzba_i = self.cp.tensor_a2a_cp2hp(
338+
qkvzba_i,
339+
seq_dim=0,
340+
head_dim=-1,
341+
split_size_or_sections=[
342+
self.qk_dim_local_tp,
343+
self.qk_dim_local_tp,
344+
self.v_dim_local_tp,
345+
self.v_dim_local_tp,
346+
self.num_value_heads // self.tp_size,
347+
self.num_value_heads // self.tp_size,
348+
],
349+
)
350+
outputs.append(qkvzba_i)
351+
qkvzba = torch.cat(outputs, dim=0)
352+
else:
353+
qkvzba = self.cp.tensor_a2a_cp2hp(
354+
qkvzba,
355+
seq_dim=0,
356+
head_dim=-1,
357+
split_size_or_sections=[
358+
self.qk_dim_local_tp,
359+
self.qk_dim_local_tp,
360+
self.v_dim_local_tp,
361+
self.v_dim_local_tp,
362+
self.num_value_heads // self.tp_size,
363+
self.num_value_heads // self.tp_size,
364+
],
365+
)
329366

330367
# Transpose: s b x --> b s x
331368
# From sbhd to bshd format
@@ -347,45 +384,18 @@ def forward(
347384
alpha = alpha.reshape(batch, seq_len, -1)
348385

349386
# Convolution on qkv
350-
qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s
351387
nvtx_range_push(suffix="conv1d")
352-
qkv_channels_split_sections = [
353-
self.qk_dim_local_tp,
354-
self.qk_dim_local_tp,
355-
self.v_dim_local_tp,
356-
]
357-
conv1d_weight = self.cp.get_parameter_local_cp(
358-
self.conv1d.weight, dim=0, split_size_or_sections=qkv_channels_split_sections
359-
)
360-
conv1d_bias = (
361-
self.cp.get_parameter_local_cp(
362-
self.conv1d.bias, dim=0, split_size_or_sections=qkv_channels_split_sections
363-
)
364-
if self.conv_bias
365-
else None
366-
)
367-
if (causal_conv1d_fn is None) or self.config.deterministic_mode:
368-
conv_out = F.conv1d(
369-
input=qkv,
370-
weight=conv1d_weight,
371-
bias=conv1d_bias,
372-
stride=self.conv1d.stride,
373-
padding=self.conv1d.padding,
374-
dilation=self.conv1d.dilation,
375-
groups=self.conv_dim_local_tp // self.cp_size,
376-
)
377-
qkv = self.act_fn(conv_out[..., :seq_len])
388+
if packed_seq_params is not None:
389+
unpacked_qkv = _unpack_sequence(qkv, cu_seqlens_q)
390+
outputs = []
391+
for qkv_i in unpacked_qkv:
392+
qkv_i = self._conv1d_on_qkv(qkv_i)
393+
outputs.append(qkv_i)
394+
qkv = torch.cat(outputs, dim=1)
378395
else:
379-
assert self.activation in ["silu", "swish"]
380-
qkv = causal_conv1d_fn(
381-
x=qkv,
382-
weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w
383-
bias=conv1d_bias,
384-
activation=self.activation,
385-
)
396+
qkv = self._conv1d_on_qkv(qkv)
386397
nvtx_range_pop(suffix="conv1d")
387398
# Split qkv into query, key, and value
388-
qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d
389399
query, key, value = torch.split(
390400
qkv,
391401
[
@@ -424,18 +434,36 @@ def forward(
424434

425435
nvtx_range_push(suffix="gated_delta_rule")
426436
if self.config.deterministic_mode:
427-
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
428-
query,
429-
key,
430-
value,
431-
g=g,
432-
beta=beta,
433-
initial_state=None,
434-
output_final_state=False,
435-
use_qk_l2norm_in_kernel=False,
436-
)
437+
gated_delta_rule_fn = torch_chunk_gated_delta_rule
437438
else:
438-
core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
439+
gated_delta_rule_fn = chunk_gated_delta_rule
440+
441+
if packed_seq_params is not None:
442+
# Packed sequence forward pass (THD format)
443+
query = _unpack_sequence(query, cu_seqlens_q)
444+
key = _unpack_sequence(key, cu_seqlens_kv)
445+
value = _unpack_sequence(value, cu_seqlens_kv)
446+
g = _unpack_sequence(g, cu_seqlens_q)
447+
beta = _unpack_sequence(beta, cu_seqlens_q)
448+
449+
outputs = []
450+
for i, (q_i, k_i, v_i, g_i, beta_i) in enumerate(zip(query, key, value, g, beta)):
451+
out_i, last_recurrent_state = gated_delta_rule_fn(
452+
q_i,
453+
k_i,
454+
v_i,
455+
g=g_i,
456+
beta=beta_i,
457+
initial_state=None,
458+
output_final_state=False,
459+
use_qk_l2norm_in_kernel=False,
460+
)
461+
outputs.append(out_i)
462+
463+
core_attn_out = torch.cat(outputs, dim=1)
464+
else:
465+
# Regular forward pass (BSHD format)
466+
core_attn_out, last_recurrent_state = gated_delta_rule_fn(
439467
query,
440468
key,
441469
value,
@@ -458,7 +486,15 @@ def forward(
458486
norm_out = norm_out.transpose(0, 1).contiguous()
459487

460488
# CP all to all: HP to CP
461-
norm_out = self.cp.tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1)
489+
if packed_seq_params is not None:
490+
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0)
491+
outputs = []
492+
for norm_out_i in unpacked_norm_out:
493+
norm_out_i = self.cp.tensor_a2a_hp2cp(norm_out_i, seq_dim=0, head_dim=-1)
494+
outputs.append(norm_out_i)
495+
norm_out = torch.cat(outputs, dim=0)
496+
else:
497+
norm_out = self.cp.tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1)
462498

463499
# Output projection
464500
nvtx_range_push(suffix="out_proj")
@@ -467,6 +503,47 @@ def forward(
467503

468504
return out, out_bias
469505

506+
def _conv1d_on_qkv(self, qkv):
    """Apply the depthwise causal convolution + activation to packed qkv.

    Args:
        qkv: Tensor in (batch, seq, channels) layout holding the
            concatenated query/key/value channels local to this TP rank.

    Returns:
        Tensor in (batch, seq, channels) layout after convolution and
        activation.
    """
    # conv1d expects channels-first, so work in (b, d, s) layout.
    x = qkv.transpose(1, 2).contiguous()
    original_seq_len = x.shape[2]

    # Slice the shared convolution parameters down to this CP rank's channels.
    split_sections = [
        self.qk_dim_local_tp,
        self.qk_dim_local_tp,
        self.v_dim_local_tp,
    ]
    weight = self.cp.get_parameter_local_cp(
        self.conv1d.weight, dim=0, split_size_or_sections=split_sections
    )
    if self.conv_bias:
        bias = self.cp.get_parameter_local_cp(
            self.conv1d.bias, dim=0, split_size_or_sections=split_sections
        )
    else:
        bias = None

    if (causal_conv1d_fn is None) or self.config.deterministic_mode:
        # Fallback path: plain F.conv1d, then crop the causal padding tail
        # back to the input length and apply the activation ourselves.
        conv_out = F.conv1d(
            input=x,
            weight=weight,
            bias=bias,
            stride=self.conv1d.stride,
            padding=self.conv1d.padding,
            dilation=self.conv1d.dilation,
            groups=self.conv_dim_local_tp // self.cp_size,
        )
        x = self.act_fn(conv_out[..., :original_seq_len])
    else:
        # Fused kernel path: the kernel applies the activation itself.
        assert self.activation in ["silu", "swish"]
        x = causal_conv1d_fn(
            x=x,
            weight=weight.squeeze(1),  # d, 1, w -> d, w
            bias=bias,
            activation=self.activation,
        )

    # Restore (b, s, d) layout for the caller.
    return x.transpose(1, 2)
546+
470547
@jit_fuser
471548
def _apply_gated_norm(self, x, gate):
472549
# Output Norm
@@ -564,6 +641,17 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_gr
564641
return sharded_state_dict
565642

566643

644+
def _unpack_sequence(x, cu_seqlens, dim=1):
645+
unpacked_x = []
646+
num_seqs = cu_seqlens.shape[0] - 1
647+
for i in range(num_seqs):
648+
idx_start = cu_seqlens[i].item()
649+
idx_end = cu_seqlens[i + 1].item()
650+
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
651+
unpacked_x.append(x[chunked_index])
652+
return unpacked_x
653+
654+
567655
def _split_tensor_factory(
568656
orig_sh_ten: ShardedTensor, split_sections: List[int], split_names: List[str], split_dim: int
569657
) -> ShardedTensorFactory:

tests/unit_tests/ssm/test_gated_delta_net.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from tests.unit_tests.test_utilities import Utils
3131
from tests.unit_tests.transformer.test_attention import _test_parallel_attention_correctness
32+
from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params
3233

3334
try:
3435
import fla
@@ -132,6 +133,68 @@ def test_gpu_forward(self):
132133
output.dtype == hidden_states.dtype
133134
), f"Output dtype {output.dtype=} mismatch with {hidden_states.dtype=}"
134135

136+
def test_gpu_forward_thd(self):
    """Smoke-test the packed-sequence (THD) forward path output shape."""
    seq_len = 32
    batch_size = 4
    cu_seqlens = [0, 32, 64, 96, 128]
    hidden_size = self.gdn.config.hidden_size
    # Per-rank slice of the sequence dimension (split across CP and SP).
    local_seq_len = seq_len // self.cp_size // self.sp_size

    # Build an sbhd input [s, b, h], then flatten it to thd [s*b, 1, h].
    hidden_sbhd = torch.rand((local_seq_len, batch_size, hidden_size))
    hidden_sbhd = hidden_sbhd.cuda().bfloat16()
    hidden_thd = (
        hidden_sbhd.transpose(0, 1).contiguous().view(-1, 1, hidden_size)
    )
    packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)

    output, _ = self.gdn(hidden_thd, None, packed_seq_params=packed_seq_params)

    assert output.shape[0] == local_seq_len * batch_size
    assert output.shape[1] == 1
    assert output.shape[2] == hidden_size
158+
159+
def test_gpu_forward_thd_correctness(self):
    """Check the THD (packed) forward matches the SBHD forward numerically."""
    if self.sp_size > 1:
        pytest.skip("Sequence parallel is not supported for this test case.")

    atol = 3e-4
    rtol = 3e-4

    seq_len = 32
    batch_size = 4
    cu_seqlens = [0, 32, 64, 96, 128]
    hidden_size = self.gdn.config.hidden_size
    # Per-CP-rank slice of the sequence dimension.
    local_seq_len = seq_len // self.cp_size

    # sbhd input [s, b, h]; the thd input is the same data reshaped
    # to [s * b, 1, h] so both runs see identical values.
    hidden_sbhd = torch.rand((local_seq_len, batch_size, hidden_size))
    hidden_sbhd = hidden_sbhd.cuda().bfloat16()
    hidden_thd = (
        hidden_sbhd.transpose(0, 1).contiguous().view(-1, 1, hidden_size)
    )
    packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)

    # Run the regular (SBHD) path and the packed (THD) path.
    out_sbhd, _ = self.gdn(hidden_sbhd, None)
    out_thd, _ = self.gdn(
        hidden_thd, None, packed_seq_params=packed_seq_params
    )

    # Bring the SBHD output into THD layout before comparing.
    out_sbhd_as_thd = out_sbhd.transpose(0, 1).contiguous().view(*out_thd.shape)
    rank = torch.distributed.get_rank()
    torch.testing.assert_close(
        out_sbhd_as_thd,
        out_thd,
        atol=atol,
        rtol=rtol,
        msg=lambda msg: f"Output mismatch ({rank=}): {msg}",
    )
197+
135198

136199
@pytest.mark.parametrize(
137200
("tp", "sp", "cp"),
@@ -146,7 +209,7 @@ def test_gpu_forward(self):
146209
@pytest.mark.skipif(not HAVE_FLA, reason="FLA is not installed.")
147210
def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp):
148211
transformer_config = TransformerConfig(
149-
hidden_size=hidden_size,
212+
hidden_size=128,
150213
linear_conv_kernel_dim=2,
151214
linear_key_head_dim=32,
152215
linear_value_head_dim=32,

0 commit comments

Comments
 (0)