Skip to content

Commit ae8806c

Browse files
support GDN packed sequence
1 parent 6e2153b commit ae8806c

File tree

3 files changed

+213
-92
lines changed

3 files changed

+213
-92
lines changed

megatron/core/ssm/gated_delta_net.py

Lines changed: 152 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,18 @@
4040
from megatron.core.utils import deprecate_inference_params, nvtx_range_pop, nvtx_range_push
4141

4242
try:
43+
from fla.modules.convolution import causal_conv1d
4344
from fla.modules.l2norm import l2norm
4445
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
4546

4647
HAVE_FLA = True
4748
except ImportError:
49+
causal_conv1d = None
50+
l2norm = None
4851
chunk_gated_delta_rule = None
4952

5053
HAVE_FLA = False
5154

52-
try:
53-
from causal_conv1d import causal_conv1d_fn
54-
except ImportError:
55-
causal_conv1d_fn = None
56-
5755

5856
logger = logging.getLogger(__name__)
5957

@@ -204,6 +202,11 @@ def __init__(
204202
)
205203
setattr(self.A_log, "tensor_model_parallel", True)
206204

205+
if self.config.deterministic_mode:
206+
self.gated_delta_rule = torch_chunk_gated_delta_rule
207+
else:
208+
self.gated_delta_rule = chunk_gated_delta_rule
209+
207210
# Output layernorm before projection
208211
self.out_norm = build_module(
209212
submodules.out_norm,
@@ -293,29 +296,71 @@ def forward(
293296
raise NotImplementedError("GDN does not support inference for now.")
294297

295298
if packed_seq_params is not None:
296-
# TODO: support packed sequence
297-
raise NotImplementedError("GDN does not support packed sequence for now.")
299+
assert batch == 1, "Packed sequence expects batch dimension to be 1"
300+
assert (
301+
not self.config.deterministic_mode
302+
), "Packed sequence does not support deterministic mode."
303+
304+
# Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
305+
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded or packed_seq_params.cu_seqlens_q
306+
# Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
307+
cu_seqlens_kv = (
308+
packed_seq_params.cu_seqlens_kv_padded or packed_seq_params.cu_seqlens_kv
309+
)
310+
assert torch.equal(cu_seqlens_q, cu_seqlens_kv), (
311+
"Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
312+
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
313+
)
314+
num_packed_seqs = cu_seqlens_q.shape[0] - 1
315+
assert num_packed_seqs > 0, (
316+
"Number of packed sequences must be greater than 0, "
317+
f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
318+
)
319+
else:
320+
cu_seqlens_q = None
321+
cu_seqlens_kv = None
298322

299323
# Input projection
300324
nvtx_range_push(suffix="in_proj")
301325
qkvzba, _ = self.in_proj(hidden_states)
302326
nvtx_range_pop(suffix="in_proj")
303327

304328
# CP All to All: CP to HP
305-
qkvzba = tensor_a2a_cp2hp(
306-
qkvzba,
307-
seq_dim=0,
308-
head_dim=-1,
309-
cp_group=self.pg_collection.cp,
310-
split_sections=[
311-
self.qk_dim_local_tp,
312-
self.qk_dim_local_tp,
313-
self.v_dim_local_tp,
314-
self.v_dim_local_tp,
315-
self.num_value_heads // self.tp_size,
316-
self.num_value_heads // self.tp_size,
317-
],
318-
)
329+
if packed_seq_params is not None:
330+
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0)
331+
outputs = []
332+
for qkvzba_i in unpacked_qkvzba:
333+
qkvzba_i = tensor_a2a_cp2hp(
334+
qkvzba_i,
335+
seq_dim=0,
336+
head_dim=-1,
337+
cp_group=self.pg_collection.cp,
338+
split_sections=[
339+
self.qk_dim_local_tp,
340+
self.qk_dim_local_tp,
341+
self.v_dim_local_tp,
342+
self.v_dim_local_tp,
343+
self.num_value_heads // self.tp_size,
344+
self.num_value_heads // self.tp_size,
345+
],
346+
)
347+
outputs.append(qkvzba_i)
348+
qkvzba = torch.cat(outputs, dim=0)
349+
else:
350+
qkvzba = tensor_a2a_cp2hp(
351+
qkvzba,
352+
seq_dim=0,
353+
head_dim=-1,
354+
cp_group=self.pg_collection.cp,
355+
split_sections=[
356+
self.qk_dim_local_tp,
357+
self.qk_dim_local_tp,
358+
self.v_dim_local_tp,
359+
self.v_dim_local_tp,
360+
self.num_value_heads // self.tp_size,
361+
self.num_value_heads // self.tp_size,
362+
],
363+
)
319364

320365
# Transpose: s b x --> b s x
321366
# From sbhd to bshd format
@@ -337,51 +382,10 @@ def forward(
337382
alpha = alpha.reshape(batch, seq_len, -1)
338383

339384
# Convolution on qkv
340-
qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s
341385
nvtx_range_push(suffix="conv1d")
342-
qkv_channels_split_sections = [
343-
self.qk_dim_local_tp,
344-
self.qk_dim_local_tp,
345-
self.v_dim_local_tp,
346-
]
347-
conv1d_weight = get_parameter_local_cp(
348-
self.conv1d.weight,
349-
dim=0,
350-
cp_group=self.pg_collection.cp,
351-
split_sections=qkv_channels_split_sections,
352-
)
353-
conv1d_bias = (
354-
get_parameter_local_cp(
355-
self.conv1d.bias,
356-
dim=0,
357-
cp_group=self.pg_collection.cp,
358-
split_sections=qkv_channels_split_sections,
359-
)
360-
if self.conv_bias
361-
else None
362-
)
363-
if (causal_conv1d_fn is None) or self.config.deterministic_mode:
364-
conv_out = F.conv1d(
365-
input=qkv,
366-
weight=conv1d_weight,
367-
bias=conv1d_bias,
368-
stride=self.conv1d.stride,
369-
padding=self.conv1d.padding,
370-
dilation=self.conv1d.dilation,
371-
groups=self.conv_dim_local_tp // self.cp_size,
372-
)
373-
qkv = self.act_fn(conv_out[..., :seq_len])
374-
else:
375-
assert self.activation in ["silu", "swish"]
376-
qkv = causal_conv1d_fn(
377-
x=qkv,
378-
weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w
379-
bias=conv1d_bias,
380-
activation=self.activation,
381-
)
386+
qkv = self._conv1d_on_qkv(qkv, cu_seqlens=cu_seqlens_q)
382387
nvtx_range_pop(suffix="conv1d")
383388
# Split qkv into query, key, and value
384-
qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d
385389
query, key, value = torch.split(
386390
qkv,
387391
[
@@ -421,28 +425,17 @@ def forward(
421425
nvtx_range_pop(suffix="g_and_beta")
422426

423427
nvtx_range_push(suffix="gated_delta_rule")
424-
if self.config.deterministic_mode:
425-
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
426-
query,
427-
key,
428-
value,
429-
g=g,
430-
beta=beta,
431-
initial_state=None,
432-
output_final_state=False,
433-
use_qk_l2norm_in_kernel=False,
434-
)
435-
else:
436-
core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
437-
query,
438-
key,
439-
value,
440-
g=g,
441-
beta=beta,
442-
initial_state=None,
443-
output_final_state=False,
444-
use_qk_l2norm_in_kernel=False,
445-
)
428+
core_attn_out, last_recurrent_state = self.gated_delta_rule(
429+
query,
430+
key,
431+
value,
432+
g=g,
433+
beta=beta,
434+
initial_state=None,
435+
output_final_state=False,
436+
use_qk_l2norm_in_kernel=False,
437+
cu_seqlens=cu_seqlens_q,
438+
)
446439
nvtx_range_pop(suffix="gated_delta_rule")
447440

448441
# RMSNorm
@@ -456,9 +449,19 @@ def forward(
456449
norm_out = norm_out.transpose(0, 1).contiguous()
457450

458451
# CP all to all: HP to CP
459-
norm_out = tensor_a2a_hp2cp(
460-
norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
461-
)
452+
if packed_seq_params is not None:
453+
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0)
454+
outputs = []
455+
for norm_out_i in unpacked_norm_out:
456+
norm_out_i = tensor_a2a_hp2cp(
457+
norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
458+
)
459+
outputs.append(norm_out_i)
460+
norm_out = torch.cat(outputs, dim=0)
461+
else:
462+
norm_out = tensor_a2a_hp2cp(
463+
norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
464+
)
462465

463466
# Output projection
464467
nvtx_range_push(suffix="out_proj")
@@ -467,6 +470,56 @@ def forward(
467470

468471
return out, out_bias
469472

473+
def _conv1d_on_qkv(self, qkv, cu_seqlens=None):
474+
seq_len = qkv.shape[1]
475+
qkv_channels_split_sections = [
476+
self.qk_dim_local_tp,
477+
self.qk_dim_local_tp,
478+
self.v_dim_local_tp,
479+
]
480+
conv1d_weight = get_parameter_local_cp(
481+
self.conv1d.weight,
482+
dim=0,
483+
cp_group=self.pg_collection.cp,
484+
split_sections=qkv_channels_split_sections,
485+
)
486+
conv1d_bias = (
487+
get_parameter_local_cp(
488+
self.conv1d.bias,
489+
dim=0,
490+
cp_group=self.pg_collection.cp,
491+
split_sections=qkv_channels_split_sections,
492+
)
493+
if self.conv_bias
494+
else None
495+
)
496+
if self.config.deterministic_mode:
497+
qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s
498+
conv_out = F.conv1d(
499+
input=qkv, # Torch-native only accept [b, d, s] format input
500+
weight=conv1d_weight,
501+
bias=conv1d_bias,
502+
stride=self.conv1d.stride,
503+
padding=self.conv1d.padding,
504+
dilation=self.conv1d.dilation,
505+
groups=self.conv_dim_local_tp // self.cp_size,
506+
)
507+
qkv = self.act_fn(conv_out[..., :seq_len])
508+
qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d
509+
else:
510+
assert self.activation in ["silu", "swish"]
511+
qkv, _ = causal_conv1d(
512+
x=qkv, # FLA conv1d accepts [b, s, d] format input
513+
weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w
514+
bias=conv1d_bias,
515+
activation=self.activation,
516+
initial_state=None,
517+
output_final_state=False,
518+
cu_seqlens=cu_seqlens,
519+
)
520+
521+
return qkv
522+
470523
@jit_fuser
471524
def _apply_gated_norm(self, x, gate):
472525
# Output Norm
@@ -564,6 +617,17 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_gr
564617
return sharded_state_dict
565618

566619

620+
def _unpack_sequence(x, cu_seqlens, dim=1):
621+
unpacked_x = []
622+
num_seqs = cu_seqlens.shape[0] - 1
623+
for i in range(num_seqs):
624+
idx_start = cu_seqlens[i].item()
625+
idx_end = cu_seqlens[i + 1].item()
626+
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
627+
unpacked_x.append(x[chunked_index])
628+
return unpacked_x
629+
630+
567631
####################
568632
# Sharded state dict utilities
569633
####################

tests/unit_tests/ssm/test_gated_delta_net.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from tests.unit_tests.test_utilities import Utils
3434
from tests.unit_tests.transformer.test_attention import _test_parallel_attention_correctness
35+
from tests.unit_tests.transformer.test_multi_latent_attention import make_test_packed_seq_params
3536

3637
try:
3738
import fla
@@ -138,7 +139,51 @@ def test_gpu_forward(self):
138139
output.dtype == hidden_states.dtype
139140
), f"Output dtype {output.dtype=} mismatch with {hidden_states.dtype=}"
140141

142+
def test_gpu_forward_thd_correctness(self):
143+
if self.sp_size > 1:
144+
pytest.skip("Sequence parallel is not supported for this test case.")
141145

146+
atol, rtol = 3e-4, 3e-4
147+
148+
# Input shape
149+
sequence_length = 32
150+
micro_batch_size = 4
151+
cu_seqlens = [0, 32, 64, 96, 128]
152+
# sbhd input shape: [sequence length, batch size, hidden size]
153+
sub_sequence_length = sequence_length // self.cp_size
154+
hidden_states_sbhd = torch.rand(
155+
(sub_sequence_length, micro_batch_size, self.gdn.config.hidden_size)
156+
)
157+
attention_mask_sbhd = None
158+
hidden_states_sbhd = hidden_states_sbhd.cuda().bfloat16()
159+
# thd input shape: [sequence length * batch size, 1, hidden size]
160+
hidden_states_thd = hidden_states_sbhd.transpose(0, 1).contiguous()
161+
hidden_states_thd = hidden_states_thd.view(-1, 1, self.gdn.config.hidden_size)
162+
attention_mask_thd = None
163+
packed_seq_params = make_test_packed_seq_params(cu_seqlens=cu_seqlens)
164+
165+
# THD format
166+
output_thd, _ = self.gdn(
167+
hidden_states_thd, attention_mask_thd, packed_seq_params=packed_seq_params
168+
)
169+
# SBHD format
170+
output_sbhd, _ = self.gdn(hidden_states_sbhd, attention_mask_sbhd)
171+
output_sbhd_T = output_sbhd.transpose(0, 1).contiguous().view(*output_thd.shape)
172+
173+
rank = torch.distributed.get_rank()
174+
assert output_thd.shape[0] == sub_sequence_length * micro_batch_size
175+
assert output_thd.shape[1] == 1
176+
assert output_thd.shape[2] == self.gdn.config.hidden_size
177+
torch.testing.assert_close(
178+
output_sbhd_T,
179+
output_thd,
180+
atol=atol,
181+
rtol=rtol,
182+
msg=lambda msg: f"Output mismatch ({rank=}): {msg}",
183+
)
184+
185+
186+
@pytest.mark.parametrize("sequence_packing", [False, True])
142187
@pytest.mark.parametrize(
143188
("tp", "sp", "cp"),
144189
[
@@ -150,7 +195,7 @@ def test_gpu_forward(self):
150195
],
151196
)
152197
@pytest.mark.skipif(not HAVE_FLA, reason="FLA is not installed.")
153-
def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp):
198+
def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, sequence_packing, tp, sp, cp):
154199
transformer_config = TransformerConfig(
155200
hidden_size=128,
156201
linear_conv_kernel_dim=2,
@@ -191,4 +236,5 @@ def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp):
191236
seed=123,
192237
sequence_length=256,
193238
micro_batch_size=4,
239+
sequence_packing=sequence_packing,
194240
)

0 commit comments

Comments (0)