@@ -99,8 +99,11 @@ def __init__(
9999 cp_comm_type: No use for GDN, just for compatibility with Attention class.
100100 """
101101
102- if not HAVE_FLA :
103- raise ImportError ("FLA is not installed. Please install it with `pip install fla`." )
102+ if not HAVE_FLA and not self .config .deterministic_mode :
103+ raise ImportError (
104+ "FLA is not installed. Please install it with "
105+ "`pip install fla` or use deterministic mode."
106+ )
104107
105108 super ().__init__ (config )
106109
@@ -304,28 +307,62 @@ def forward(
304307 raise NotImplementedError ("GDN does not support inference for now." )
305308
306309 if packed_seq_params is not None :
307- # TODO: support packed sequence
308- raise NotImplementedError ("GDN does not support packed sequence for now." )
310+ assert batch == 1 , "Packed sequence expects batch dimension to be 1"
311+ # Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
312+ cu_seqlens_q = packed_seq_params .cu_seqlens_q_padded or packed_seq_params .cu_seqlens_q
313+ # Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
314+ cu_seqlens_kv = (
315+ packed_seq_params .cu_seqlens_kv_padded or packed_seq_params .cu_seqlens_kv
316+ )
317+ assert torch .equal (cu_seqlens_q , cu_seqlens_kv ), (
318+ "Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
319+ f"but got { cu_seqlens_q = } and { cu_seqlens_kv = } "
320+ )
321+ num_packed_seqs = cu_seqlens_q .shape [0 ] - 1
322+ assert num_packed_seqs > 0 , (
323+ "Number of packed sequences must be greater than 0, "
324+ f"but got { cu_seqlens_q = } and { cu_seqlens_kv = } "
325+ )
309326
310327 # Input projection
311328 nvtx_range_push (suffix = "in_proj" )
312329 qkvzba , _ = self .in_proj (hidden_states )
313330 nvtx_range_pop (suffix = "in_proj" )
314331
315332 # CP All to All: CP to HP
316- qkvzba = self .cp .tensor_a2a_cp2hp (
317- qkvzba ,
318- seq_dim = 0 ,
319- head_dim = - 1 ,
320- split_size_or_sections = [
321- self .qk_dim_local_tp ,
322- self .qk_dim_local_tp ,
323- self .v_dim_local_tp ,
324- self .v_dim_local_tp ,
325- self .num_value_heads // self .tp_size ,
326- self .num_value_heads // self .tp_size ,
327- ],
328- )
333+ if packed_seq_params is not None :
334+ unpacked_qkvzba = _unpack_sequence (qkvzba , cu_seqlens_q // self .cp_size , dim = 0 )
335+ outputs = []
336+ for qkvzba_i in unpacked_qkvzba :
337+ qkvzba_i = self .cp .tensor_a2a_cp2hp (
338+ qkvzba_i ,
339+ seq_dim = 0 ,
340+ head_dim = - 1 ,
341+ split_size_or_sections = [
342+ self .qk_dim_local_tp ,
343+ self .qk_dim_local_tp ,
344+ self .v_dim_local_tp ,
345+ self .v_dim_local_tp ,
346+ self .num_value_heads // self .tp_size ,
347+ self .num_value_heads // self .tp_size ,
348+ ],
349+ )
350+ outputs .append (qkvzba_i )
351+ qkvzba = torch .cat (outputs , dim = 0 )
352+ else :
353+ qkvzba = self .cp .tensor_a2a_cp2hp (
354+ qkvzba ,
355+ seq_dim = 0 ,
356+ head_dim = - 1 ,
357+ split_size_or_sections = [
358+ self .qk_dim_local_tp ,
359+ self .qk_dim_local_tp ,
360+ self .v_dim_local_tp ,
361+ self .v_dim_local_tp ,
362+ self .num_value_heads // self .tp_size ,
363+ self .num_value_heads // self .tp_size ,
364+ ],
365+ )
329366
330367 # Transpose: s b x --> b s x
331368 # From sbhd to bshd format
@@ -347,45 +384,18 @@ def forward(
347384 alpha = alpha .reshape (batch , seq_len , - 1 )
348385
349386 # Convolution on qkv
350- qkv = qkv .transpose (1 , 2 ).contiguous () # b, s, d -> b, d, s
351387 nvtx_range_push (suffix = "conv1d" )
352- qkv_channels_split_sections = [
353- self .qk_dim_local_tp ,
354- self .qk_dim_local_tp ,
355- self .v_dim_local_tp ,
356- ]
357- conv1d_weight = self .cp .get_parameter_local_cp (
358- self .conv1d .weight , dim = 0 , split_size_or_sections = qkv_channels_split_sections
359- )
360- conv1d_bias = (
361- self .cp .get_parameter_local_cp (
362- self .conv1d .bias , dim = 0 , split_size_or_sections = qkv_channels_split_sections
363- )
364- if self .conv_bias
365- else None
366- )
367- if (causal_conv1d_fn is None ) or self .config .deterministic_mode :
368- conv_out = F .conv1d (
369- input = qkv ,
370- weight = conv1d_weight ,
371- bias = conv1d_bias ,
372- stride = self .conv1d .stride ,
373- padding = self .conv1d .padding ,
374- dilation = self .conv1d .dilation ,
375- groups = self .conv_dim_local_tp // self .cp_size ,
376- )
377- qkv = self .act_fn (conv_out [..., :seq_len ])
388+ if packed_seq_params is not None :
389+ unpacked_qkv = _unpack_sequence (qkv , cu_seqlens_q )
390+ outputs = []
391+ for qkv_i in unpacked_qkv :
392+ qkv_i = self ._conv1d_on_qkv (qkv_i )
393+ outputs .append (qkv_i )
394+ qkv = torch .cat (outputs , dim = 1 )
378395 else :
379- assert self .activation in ["silu" , "swish" ]
380- qkv = causal_conv1d_fn (
381- x = qkv ,
382- weight = conv1d_weight .squeeze (1 ), # d, 1, w -> d, w
383- bias = conv1d_bias ,
384- activation = self .activation ,
385- )
396+ qkv = self ._conv1d_on_qkv (qkv )
386397 nvtx_range_pop (suffix = "conv1d" )
387398 # Split qkv into query, key, and value
388- qkv = qkv .transpose (1 , 2 ) # b, d, s -> b, s, d
389399 query , key , value = torch .split (
390400 qkv ,
391401 [
@@ -424,18 +434,36 @@ def forward(
424434
425435 nvtx_range_push (suffix = "gated_delta_rule" )
426436 if self .config .deterministic_mode :
427- core_attn_out , last_recurrent_state = torch_chunk_gated_delta_rule (
428- query ,
429- key ,
430- value ,
431- g = g ,
432- beta = beta ,
433- initial_state = None ,
434- output_final_state = False ,
435- use_qk_l2norm_in_kernel = False ,
436- )
437+ gated_delta_rule_fn = torch_chunk_gated_delta_rule
437438 else :
438- core_attn_out , last_recurrent_state = chunk_gated_delta_rule (
439+ gated_delta_rule_fn = chunk_gated_delta_rule
440+
441+ if packed_seq_params is not None :
442+ # Packed sequence forward pass (THD format)
443+ query = _unpack_sequence (query , cu_seqlens_q )
444+ key = _unpack_sequence (key , cu_seqlens_kv )
445+ value = _unpack_sequence (value , cu_seqlens_kv )
446+ g = _unpack_sequence (g , cu_seqlens_q )
447+ beta = _unpack_sequence (beta , cu_seqlens_q )
448+
449+ outputs = []
450+ for i , (q_i , k_i , v_i , g_i , beta_i ) in enumerate (zip (query , key , value , g , beta )):
451+ out_i , last_recurrent_state = gated_delta_rule_fn (
452+ q_i ,
453+ k_i ,
454+ v_i ,
455+ g = g_i ,
456+ beta = beta_i ,
457+ initial_state = None ,
458+ output_final_state = False ,
459+ use_qk_l2norm_in_kernel = False ,
460+ )
461+ outputs .append (out_i )
462+
463+ core_attn_out = torch .cat (outputs , dim = 1 )
464+ else :
465+ # Regular forward pass (BSHD format)
466+ core_attn_out , last_recurrent_state = gated_delta_rule_fn (
439467 query ,
440468 key ,
441469 value ,
@@ -458,7 +486,15 @@ def forward(
458486 norm_out = norm_out .transpose (0 , 1 ).contiguous ()
459487
460488 # CP all to all: HP to CP
461- norm_out = self .cp .tensor_a2a_hp2cp (norm_out , seq_dim = 0 , head_dim = - 1 )
489+ if packed_seq_params is not None :
490+ unpacked_norm_out = _unpack_sequence (norm_out , cu_seqlens_q , dim = 0 )
491+ outputs = []
492+ for norm_out_i in unpacked_norm_out :
493+ norm_out_i = self .cp .tensor_a2a_hp2cp (norm_out_i , seq_dim = 0 , head_dim = - 1 )
494+ outputs .append (norm_out_i )
495+ norm_out = torch .cat (outputs , dim = 0 )
496+ else :
497+ norm_out = self .cp .tensor_a2a_hp2cp (norm_out , seq_dim = 0 , head_dim = - 1 )
462498
463499 # Output projection
464500 nvtx_range_push (suffix = "out_proj" )
@@ -467,6 +503,47 @@ def forward(
467503
468504 return out , out_bias
469505
506+ def _conv1d_on_qkv (self , qkv ):
507+ qkv = qkv .transpose (1 , 2 ).contiguous () # b, s, d -> b, d, s
508+ seq_len = qkv .shape [2 ]
509+ qkv_channels_split_sections = [
510+ self .qk_dim_local_tp ,
511+ self .qk_dim_local_tp ,
512+ self .v_dim_local_tp ,
513+ ]
514+ conv1d_weight = self .cp .get_parameter_local_cp (
515+ self .conv1d .weight , dim = 0 , split_size_or_sections = qkv_channels_split_sections
516+ )
517+ conv1d_bias = (
518+ self .cp .get_parameter_local_cp (
519+ self .conv1d .bias , dim = 0 , split_size_or_sections = qkv_channels_split_sections
520+ )
521+ if self .conv_bias
522+ else None
523+ )
524+ if (causal_conv1d_fn is None ) or self .config .deterministic_mode :
525+ conv_out = F .conv1d (
526+ input = qkv ,
527+ weight = conv1d_weight ,
528+ bias = conv1d_bias ,
529+ stride = self .conv1d .stride ,
530+ padding = self .conv1d .padding ,
531+ dilation = self .conv1d .dilation ,
532+ groups = self .conv_dim_local_tp // self .cp_size ,
533+ )
534+ qkv = self .act_fn (conv_out [..., :seq_len ])
535+ else :
536+ assert self .activation in ["silu" , "swish" ]
537+ qkv = causal_conv1d_fn (
538+ x = qkv ,
539+ weight = conv1d_weight .squeeze (1 ), # d, 1, w -> d, w
540+ bias = conv1d_bias ,
541+ activation = self .activation ,
542+ )
543+ qkv = qkv .transpose (1 , 2 ) # b, d, s -> b, s, d
544+
545+ return qkv
546+
470547 @jit_fuser
471548 def _apply_gated_norm (self , x , gate ):
472549 # Output Norm
@@ -564,6 +641,17 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_gr
564641 return sharded_state_dict
565642
566643
644+ def _unpack_sequence (x , cu_seqlens , dim = 1 ):
645+ unpacked_x = []
646+ num_seqs = cu_seqlens .shape [0 ] - 1
647+ for i in range (num_seqs ):
648+ idx_start = cu_seqlens [i ].item ()
649+ idx_end = cu_seqlens [i + 1 ].item ()
650+ chunked_index = [slice (None )] * dim + [slice (idx_start , idx_end )]
651+ unpacked_x .append (x [chunked_index ])
652+ return unpacked_x
653+
654+
567655def _split_tensor_factory (
568656 orig_sh_ten : ShardedTensor , split_sections : List [int ], split_names : List [str ], split_dim : int
569657) -> ShardedTensorFactory :
0 commit comments