@@ -293,29 +293,64 @@ def forward(
293293 raise NotImplementedError ("GDN does not support inference for now." )
294294
295295 if packed_seq_params is not None :
296- # TODO: support packed sequence
297- raise NotImplementedError ("GDN does not support packed sequence for now." )
296+ assert batch == 1 , "Packed sequence expects batch dimension to be 1"
297+ # Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
298+ cu_seqlens_q = packed_seq_params .cu_seqlens_q_padded or packed_seq_params .cu_seqlens_q
299+ # Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
300+ cu_seqlens_kv = (
301+ packed_seq_params .cu_seqlens_kv_padded or packed_seq_params .cu_seqlens_kv
302+ )
303+ assert torch .equal (cu_seqlens_q , cu_seqlens_kv ), (
304+ "Currently only support cu_seqlens_q equals to cu_seqlens_kv, "
305+ f"but got { cu_seqlens_q = } and { cu_seqlens_kv = } "
306+ )
307+ num_packed_seqs = cu_seqlens_q .shape [0 ] - 1
308+ assert num_packed_seqs > 0 , (
309+ "Number of packed sequences must be greater than 0, "
310+ f"but got { cu_seqlens_q = } and { cu_seqlens_kv = } "
311+ )
298312
299313 # Input projection
300314 nvtx_range_push (suffix = "in_proj" )
301315 qkvzba , _ = self .in_proj (hidden_states )
302316 nvtx_range_pop (suffix = "in_proj" )
303317
304318 # CP All to All: CP to HP
305- qkvzba = tensor_a2a_cp2hp (
306- qkvzba ,
307- seq_dim = 0 ,
308- head_dim = - 1 ,
309- cp_group = self .pg_collection .cp ,
310- split_sections = [
311- self .qk_dim_local_tp ,
312- self .qk_dim_local_tp ,
313- self .v_dim_local_tp ,
314- self .v_dim_local_tp ,
315- self .num_value_heads // self .tp_size ,
316- self .num_value_heads // self .tp_size ,
317- ],
318- )
319+ if packed_seq_params is not None :
320+ unpacked_qkvzba = _unpack_sequence (qkvzba , cu_seqlens_q // self .cp_size , dim = 0 )
321+ outputs = []
322+ for qkvzba_i in unpacked_qkvzba :
323+ qkvzba_i = tensor_a2a_cp2hp (
324+ qkvzba_i ,
325+ seq_dim = 0 ,
326+ head_dim = - 1 ,
327+ cp_group = self .pg_collection .cp ,
328+ split_sections = [
329+ self .qk_dim_local_tp ,
330+ self .qk_dim_local_tp ,
331+ self .v_dim_local_tp ,
332+ self .v_dim_local_tp ,
333+ self .num_value_heads // self .tp_size ,
334+ self .num_value_heads // self .tp_size ,
335+ ],
336+ )
337+ outputs .append (qkvzba_i )
338+ qkvzba = torch .cat (outputs , dim = 0 )
339+ else :
340+ qkvzba = tensor_a2a_cp2hp (
341+ qkvzba ,
342+ seq_dim = 0 ,
343+ head_dim = - 1 ,
344+ cp_group = self .pg_collection .cp ,
345+ split_sections = [
346+ self .qk_dim_local_tp ,
347+ self .qk_dim_local_tp ,
348+ self .v_dim_local_tp ,
349+ self .v_dim_local_tp ,
350+ self .num_value_heads // self .tp_size ,
351+ self .num_value_heads // self .tp_size ,
352+ ],
353+ )
319354
320355 # Transpose: s b x --> b s x
321356 # From sbhd to bshd format
@@ -337,51 +372,18 @@ def forward(
337372 alpha = alpha .reshape (batch , seq_len , - 1 )
338373
339374 # Convolution on qkv
340- qkv = qkv .transpose (1 , 2 ).contiguous () # b, s, d -> b, d, s
341375 nvtx_range_push (suffix = "conv1d" )
342- qkv_channels_split_sections = [
343- self .qk_dim_local_tp ,
344- self .qk_dim_local_tp ,
345- self .v_dim_local_tp ,
346- ]
347- conv1d_weight = get_parameter_local_cp (
348- self .conv1d .weight ,
349- dim = 0 ,
350- cp_group = self .pg_collection .cp ,
351- split_sections = qkv_channels_split_sections ,
352- )
353- conv1d_bias = (
354- get_parameter_local_cp (
355- self .conv1d .bias ,
356- dim = 0 ,
357- cp_group = self .pg_collection .cp ,
358- split_sections = qkv_channels_split_sections ,
359- )
360- if self .conv_bias
361- else None
362- )
363- if (causal_conv1d_fn is None ) or self .config .deterministic_mode :
364- conv_out = F .conv1d (
365- input = qkv ,
366- weight = conv1d_weight ,
367- bias = conv1d_bias ,
368- stride = self .conv1d .stride ,
369- padding = self .conv1d .padding ,
370- dilation = self .conv1d .dilation ,
371- groups = self .conv_dim_local_tp // self .cp_size ,
372- )
373- qkv = self .act_fn (conv_out [..., :seq_len ])
376+ if packed_seq_params is not None :
377+ unpacked_qkv = _unpack_sequence (qkv , cu_seqlens_q )
378+ outputs = []
379+ for qkv_i in unpacked_qkv :
380+ qkv_i = self ._conv1d_on_qkv (qkv_i )
381+ outputs .append (qkv_i )
382+ qkv = torch .cat (outputs , dim = 1 )
374383 else :
375- assert self .activation in ["silu" , "swish" ]
376- qkv = causal_conv1d_fn (
377- x = qkv ,
378- weight = conv1d_weight .squeeze (1 ), # d, 1, w -> d, w
379- bias = conv1d_bias ,
380- activation = self .activation ,
381- )
384+ qkv = self ._conv1d_on_qkv (qkv )
382385 nvtx_range_pop (suffix = "conv1d" )
383386 # Split qkv into query, key, and value
384- qkv = qkv .transpose (1 , 2 ) # b, d, s -> b, s, d
385387 query , key , value = torch .split (
386388 qkv ,
387389 [
@@ -422,18 +424,36 @@ def forward(
422424
423425 nvtx_range_push (suffix = "gated_delta_rule" )
424426 if self .config .deterministic_mode :
425- core_attn_out , last_recurrent_state = torch_chunk_gated_delta_rule (
426- query ,
427- key ,
428- value ,
429- g = g ,
430- beta = beta ,
431- initial_state = None ,
432- output_final_state = False ,
433- use_qk_l2norm_in_kernel = False ,
434- )
427+ gated_delta_rule_fn = torch_chunk_gated_delta_rule
428+ else :
429+ gated_delta_rule_fn = chunk_gated_delta_rule
430+
431+ if packed_seq_params is not None :
432+ # Packed sequence forward pass (THD format)
433+ query = _unpack_sequence (query , cu_seqlens_q )
434+ key = _unpack_sequence (key , cu_seqlens_kv )
435+ value = _unpack_sequence (value , cu_seqlens_kv )
436+ g = _unpack_sequence (g , cu_seqlens_q )
437+ beta = _unpack_sequence (beta , cu_seqlens_q )
438+
439+ outputs = []
440+ for i , (q_i , k_i , v_i , g_i , beta_i ) in enumerate (zip (query , key , value , g , beta )):
441+ out_i , last_recurrent_state = gated_delta_rule_fn (
442+ q_i ,
443+ k_i ,
444+ v_i ,
445+ g = g_i ,
446+ beta = beta_i ,
447+ initial_state = None ,
448+ output_final_state = False ,
449+ use_qk_l2norm_in_kernel = False ,
450+ )
451+ outputs .append (out_i )
452+
453+ core_attn_out = torch .cat (outputs , dim = 1 )
435454 else :
436- core_attn_out , last_recurrent_state = chunk_gated_delta_rule (
455+ # Regular forward pass (BSHD format)
456+ core_attn_out , last_recurrent_state = gated_delta_rule_fn (
437457 query ,
438458 key ,
439459 value ,
@@ -456,9 +476,19 @@ def forward(
456476 norm_out = norm_out .transpose (0 , 1 ).contiguous ()
457477
458478 # CP all to all: HP to CP
459- norm_out = tensor_a2a_hp2cp (
460- norm_out , seq_dim = 0 , head_dim = - 1 , cp_group = self .pg_collection .cp
461- )
479+ if packed_seq_params is not None :
480+ unpacked_norm_out = _unpack_sequence (norm_out , cu_seqlens_q , dim = 0 )
481+ outputs = []
482+ for norm_out_i in unpacked_norm_out :
483+ norm_out_i = tensor_a2a_hp2cp (
484+ norm_out_i , seq_dim = 0 , head_dim = - 1 , cp_group = self .pg_collection .cp
485+ )
486+ outputs .append (norm_out_i )
487+ norm_out = torch .cat (outputs , dim = 0 )
488+ else :
489+ norm_out = tensor_a2a_hp2cp (
490+ norm_out , seq_dim = 0 , head_dim = - 1 , cp_group = self .pg_collection .cp
491+ )
462492
463493 # Output projection
464494 nvtx_range_push (suffix = "out_proj" )
@@ -467,6 +497,53 @@ def forward(
467497
468498 return out , out_bias
469499
def _conv1d_on_qkv(self, qkv):
    """Run the conv1d + activation stage over the concatenated q/k/v channels.

    The input is laid out as (b, s, d); it is moved to channels-first for the
    convolution and moved back before returning, so the result has the same
    layout as the input.

    Two code paths exist:
      * ``F.conv1d`` followed by ``self.act_fn`` when the fused kernel is
        unavailable or deterministic mode is requested;
      * the fused ``causal_conv1d_fn`` kernel otherwise (silu/swish only).

    Args:
        qkv: Tensor holding query, key and value channels, (b, s, d).

    Returns:
        The convolved and activated tensor, (b, s, d).
    """
    channel_sections = [
        self.qk_dim_local_tp,
        self.qk_dim_local_tp,
        self.v_dim_local_tp,
    ]

    def _local_cp_slice(param):
        # Slice a conv parameter along dim 0 for this context-parallel rank.
        return get_parameter_local_cp(
            param,
            dim=0,
            cp_group=self.pg_collection.cp,
            split_sections=channel_sections,
        )

    channels_first = qkv.transpose(1, 2).contiguous()  # b, s, d -> b, d, s
    num_steps = channels_first.shape[2]

    weight = _local_cp_slice(self.conv1d.weight)
    bias = _local_cp_slice(self.conv1d.bias) if self.conv_bias else None

    if (causal_conv1d_fn is None) or self.config.deterministic_mode:
        conv_out = F.conv1d(
            input=channels_first,
            weight=weight,
            bias=bias,
            stride=self.conv1d.stride,
            padding=self.conv1d.padding,
            dilation=self.conv1d.dilation,
            groups=self.conv_dim_local_tp // self.cp_size,
        )
        # The conv output can be longer than the input along the last axis;
        # keep only the first ``num_steps`` positions.
        activated = self.act_fn(conv_out[..., :num_steps])
    else:
        assert self.activation in ["silu", "swish"]
        activated = causal_conv1d_fn(
            x=channels_first,
            weight=weight.squeeze(1),  # d, 1, w -> d, w
            bias=bias,
            activation=self.activation,
        )

    return activated.transpose(1, 2)  # b, d, s -> b, s, d
546+
470547 @jit_fuser
471548 def _apply_gated_norm (self , x , gate ):
472549 # Output Norm
@@ -564,6 +641,17 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_gr
564641 return sharded_state_dict
565642
566643
644+ def _unpack_sequence (x , cu_seqlens , dim = 1 ):
645+ unpacked_x = []
646+ num_seqs = cu_seqlens .shape [0 ] - 1
647+ for i in range (num_seqs ):
648+ idx_start = cu_seqlens [i ].item ()
649+ idx_end = cu_seqlens [i + 1 ].item ()
650+ chunked_index = [slice (None )] * dim + [slice (idx_start , idx_end )]
651+ unpacked_x .append (x [chunked_index ])
652+ return unpacked_x
653+
654+
567655####################
568656# Sharded state dict utilities
569657####################
0 commit comments