4040from megatron .core .utils import deprecate_inference_params , nvtx_range_pop , nvtx_range_push
4141
4242try :
43+ from fla .modules .convolution import causal_conv1d
4344 from fla .modules .l2norm import l2norm
4445 from fla .ops .gated_delta_rule import chunk_gated_delta_rule
4546
4647 HAVE_FLA = True
4748except ImportError :
49+ causal_conv1d = None
50+ l2norm = None
4851 chunk_gated_delta_rule = None
4952
5053 HAVE_FLA = False
5154
52- try :
53- from causal_conv1d import causal_conv1d_fn
54- except ImportError :
55- causal_conv1d_fn = None
56-
5755
# Module-level logger, keyed to this module's import path per the stdlib convention.
logger = logging.getLogger(__name__)
5957
@@ -204,6 +202,11 @@ def __init__(
204202 )
205203 setattr (self .A_log , "tensor_model_parallel" , True )
206204
205+ if self .config .deterministic_mode :
206+ self .gated_delta_rule = torch_chunk_gated_delta_rule
207+ else :
208+ self .gated_delta_rule = chunk_gated_delta_rule
209+
207210 # Output layernorm before projection
208211 self .out_norm = build_module (
209212 submodules .out_norm ,
@@ -293,29 +296,71 @@ def forward(
293296 raise NotImplementedError ("GDN does not support inference for now." )
294297
295298 if packed_seq_params is not None :
296- # TODO: support packed sequence
297- raise NotImplementedError ("GDN does not support packed sequence for now." )
299+ assert batch == 1 , "Packed sequence expects batch dimension to be 1"
300+ assert (
301+ not self .config .deterministic_mode
302+ ), "Packed sequence does not support deterministic mode."
303+
304+ # Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
305+ cu_seqlens_q = packed_seq_params .cu_seqlens_q_padded or packed_seq_params .cu_seqlens_q
306+ # Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
307+ cu_seqlens_kv = (
308+ packed_seq_params .cu_seqlens_kv_padded or packed_seq_params .cu_seqlens_kv
309+ )
310+ assert torch .equal (cu_seqlens_q , cu_seqlens_kv ), (
311+ "Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
312+ f"but got { cu_seqlens_q = } and { cu_seqlens_kv = } "
313+ )
314+ num_packed_seqs = cu_seqlens_q .shape [0 ] - 1
315+ assert num_packed_seqs > 0 , (
316+ "Number of packed sequences must be greater than 0, "
317+ f"but got { cu_seqlens_q = } and { cu_seqlens_kv = } "
318+ )
319+ else :
320+ cu_seqlens_q = None
321+ cu_seqlens_kv = None
298322
299323 # Input projection
300324 nvtx_range_push (suffix = "in_proj" )
301325 qkvzba , _ = self .in_proj (hidden_states )
302326 nvtx_range_pop (suffix = "in_proj" )
303327
304328 # CP All to All: CP to HP
305- qkvzba = tensor_a2a_cp2hp (
306- qkvzba ,
307- seq_dim = 0 ,
308- head_dim = - 1 ,
309- cp_group = self .pg_collection .cp ,
310- split_sections = [
311- self .qk_dim_local_tp ,
312- self .qk_dim_local_tp ,
313- self .v_dim_local_tp ,
314- self .v_dim_local_tp ,
315- self .num_value_heads // self .tp_size ,
316- self .num_value_heads // self .tp_size ,
317- ],
318- )
329+ if packed_seq_params is not None :
330+ unpacked_qkvzba = _unpack_sequence (qkvzba , cu_seqlens_q // self .cp_size , dim = 0 )
331+ outputs = []
332+ for qkvzba_i in unpacked_qkvzba :
333+ qkvzba_i = tensor_a2a_cp2hp (
334+ qkvzba_i ,
335+ seq_dim = 0 ,
336+ head_dim = - 1 ,
337+ cp_group = self .pg_collection .cp ,
338+ split_sections = [
339+ self .qk_dim_local_tp ,
340+ self .qk_dim_local_tp ,
341+ self .v_dim_local_tp ,
342+ self .v_dim_local_tp ,
343+ self .num_value_heads // self .tp_size ,
344+ self .num_value_heads // self .tp_size ,
345+ ],
346+ )
347+ outputs .append (qkvzba_i )
348+ qkvzba = torch .cat (outputs , dim = 0 )
349+ else :
350+ qkvzba = tensor_a2a_cp2hp (
351+ qkvzba ,
352+ seq_dim = 0 ,
353+ head_dim = - 1 ,
354+ cp_group = self .pg_collection .cp ,
355+ split_sections = [
356+ self .qk_dim_local_tp ,
357+ self .qk_dim_local_tp ,
358+ self .v_dim_local_tp ,
359+ self .v_dim_local_tp ,
360+ self .num_value_heads // self .tp_size ,
361+ self .num_value_heads // self .tp_size ,
362+ ],
363+ )
319364
320365 # Transpose: s b x --> b s x
321366 # From sbhd to bshd format
@@ -337,51 +382,10 @@ def forward(
337382 alpha = alpha .reshape (batch , seq_len , - 1 )
338383
339384 # Convolution on qkv
340- qkv = qkv .transpose (1 , 2 ).contiguous () # b, s, d -> b, d, s
341385 nvtx_range_push (suffix = "conv1d" )
342- qkv_channels_split_sections = [
343- self .qk_dim_local_tp ,
344- self .qk_dim_local_tp ,
345- self .v_dim_local_tp ,
346- ]
347- conv1d_weight = get_parameter_local_cp (
348- self .conv1d .weight ,
349- dim = 0 ,
350- cp_group = self .pg_collection .cp ,
351- split_sections = qkv_channels_split_sections ,
352- )
353- conv1d_bias = (
354- get_parameter_local_cp (
355- self .conv1d .bias ,
356- dim = 0 ,
357- cp_group = self .pg_collection .cp ,
358- split_sections = qkv_channels_split_sections ,
359- )
360- if self .conv_bias
361- else None
362- )
363- if (causal_conv1d_fn is None ) or self .config .deterministic_mode :
364- conv_out = F .conv1d (
365- input = qkv ,
366- weight = conv1d_weight ,
367- bias = conv1d_bias ,
368- stride = self .conv1d .stride ,
369- padding = self .conv1d .padding ,
370- dilation = self .conv1d .dilation ,
371- groups = self .conv_dim_local_tp // self .cp_size ,
372- )
373- qkv = self .act_fn (conv_out [..., :seq_len ])
374- else :
375- assert self .activation in ["silu" , "swish" ]
376- qkv = causal_conv1d_fn (
377- x = qkv ,
378- weight = conv1d_weight .squeeze (1 ), # d, 1, w -> d, w
379- bias = conv1d_bias ,
380- activation = self .activation ,
381- )
386+ qkv = self ._conv1d_on_qkv (qkv , cu_seqlens = cu_seqlens_q )
382387 nvtx_range_pop (suffix = "conv1d" )
383388 # Split qkv into query, key, and value
384- qkv = qkv .transpose (1 , 2 ) # b, d, s -> b, s, d
385389 query , key , value = torch .split (
386390 qkv ,
387391 [
@@ -421,28 +425,17 @@ def forward(
421425 nvtx_range_pop (suffix = "g_and_beta" )
422426
423427 nvtx_range_push (suffix = "gated_delta_rule" )
424- if self .config .deterministic_mode :
425- core_attn_out , last_recurrent_state = torch_chunk_gated_delta_rule (
426- query ,
427- key ,
428- value ,
429- g = g ,
430- beta = beta ,
431- initial_state = None ,
432- output_final_state = False ,
433- use_qk_l2norm_in_kernel = False ,
434- )
435- else :
436- core_attn_out , last_recurrent_state = chunk_gated_delta_rule (
437- query ,
438- key ,
439- value ,
440- g = g ,
441- beta = beta ,
442- initial_state = None ,
443- output_final_state = False ,
444- use_qk_l2norm_in_kernel = False ,
445- )
428+ core_attn_out , last_recurrent_state = self .gated_delta_rule (
429+ query ,
430+ key ,
431+ value ,
432+ g = g ,
433+ beta = beta ,
434+ initial_state = None ,
435+ output_final_state = False ,
436+ use_qk_l2norm_in_kernel = False ,
437+ cu_seqlens = cu_seqlens_q ,
438+ )
446439 nvtx_range_pop (suffix = "gated_delta_rule" )
447440
448441 # RMSNorm
@@ -456,9 +449,19 @@ def forward(
456449 norm_out = norm_out .transpose (0 , 1 ).contiguous ()
457450
458451 # CP all to all: HP to CP
459- norm_out = tensor_a2a_hp2cp (
460- norm_out , seq_dim = 0 , head_dim = - 1 , cp_group = self .pg_collection .cp
461- )
452+ if packed_seq_params is not None :
453+ unpacked_norm_out = _unpack_sequence (norm_out , cu_seqlens_q , dim = 0 )
454+ outputs = []
455+ for norm_out_i in unpacked_norm_out :
456+ norm_out_i = tensor_a2a_hp2cp (
457+ norm_out_i , seq_dim = 0 , head_dim = - 1 , cp_group = self .pg_collection .cp
458+ )
459+ outputs .append (norm_out_i )
460+ norm_out = torch .cat (outputs , dim = 0 )
461+ else :
462+ norm_out = tensor_a2a_hp2cp (
463+ norm_out , seq_dim = 0 , head_dim = - 1 , cp_group = self .pg_collection .cp
464+ )
462465
463466 # Output projection
464467 nvtx_range_push (suffix = "out_proj" )
@@ -467,6 +470,56 @@ def forward(
467470
468471 return out , out_bias
469472
def _conv1d_on_qkv(self, qkv, cu_seqlens=None):
    """Apply the CP-sharded causal depthwise convolution to the fused qkv tensor.

    Args:
        qkv: Fused query/key/value activations in [b, s, d] layout.
        cu_seqlens: Optional cumulative sequence lengths for packed sequences.
            Only the FLA kernel path supports packed sequences.

    Returns:
        The convolved and activated qkv tensor, in [b, s, d] layout.
    """
    seq_len = qkv.shape[1]
    # Channel split of the fused conv parameters across q, k, and v.
    qkv_channels_split_sections = [
        self.qk_dim_local_tp,
        self.qk_dim_local_tp,
        self.v_dim_local_tp,
    ]
    conv1d_weight = get_parameter_local_cp(
        self.conv1d.weight,
        dim=0,
        cp_group=self.pg_collection.cp,
        split_sections=qkv_channels_split_sections,
    )
    conv1d_bias = (
        get_parameter_local_cp(
            self.conv1d.bias,
            dim=0,
            cp_group=self.pg_collection.cp,
            split_sections=qkv_channels_split_sections,
        )
        if self.conv_bias
        else None
    )
    # Use the torch-native conv when determinism is requested OR when the fused
    # FLA kernel is unavailable (its import failed and `causal_conv1d` is None);
    # previously the missing-kernel case crashed by calling None.
    if self.config.deterministic_mode or causal_conv1d is None:
        assert cu_seqlens is None, (
            "Packed sequences require the FLA causal_conv1d kernel, which is "
            "unavailable or disabled by deterministic mode."
        )
        qkv = qkv.transpose(1, 2).contiguous()  # b, s, d -> b, d, s
        conv_out = F.conv1d(
            input=qkv,  # Torch-native conv only accepts [b, d, s] format input
            weight=conv1d_weight,
            bias=conv1d_bias,
            stride=self.conv1d.stride,
            padding=self.conv1d.padding,
            dilation=self.conv1d.dilation,
            groups=self.conv_dim_local_tp // self.cp_size,
        )
        # Trim the causal padding back to the original sequence length.
        qkv = self.act_fn(conv_out[..., :seq_len])
        qkv = qkv.transpose(1, 2)  # b, d, s -> b, s, d
    else:
        assert self.activation in ["silu", "swish"]
        qkv, _ = causal_conv1d(
            x=qkv,  # FLA conv1d accepts [b, s, d] format input
            weight=conv1d_weight.squeeze(1),  # d, 1, w -> d, w
            bias=conv1d_bias,
            activation=self.activation,
            initial_state=None,
            output_final_state=False,
            cu_seqlens=cu_seqlens,
        )

    return qkv
470523 @jit_fuser
471524 def _apply_gated_norm (self , x , gate ):
472525 # Output Norm
@@ -564,6 +617,17 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_gr
564617 return sharded_state_dict
565618
566619
620+ def _unpack_sequence (x , cu_seqlens , dim = 1 ):
621+ unpacked_x = []
622+ num_seqs = cu_seqlens .shape [0 ] - 1
623+ for i in range (num_seqs ):
624+ idx_start = cu_seqlens [i ].item ()
625+ idx_end = cu_seqlens [i + 1 ].item ()
626+ chunked_index = [slice (None )] * dim + [slice (idx_start , idx_end )]
627+ unpacked_x .append (x [chunked_index ])
628+ return unpacked_x
629+
630+
567631####################
568632# Sharded state dict utilities
569633####################
0 commit comments