@@ -296,29 +296,71 @@ def forward(
             raise NotImplementedError("GDN does not support inference for now.")
 
         if packed_seq_params is not None:
-            # TODO: support packed sequence
-            raise NotImplementedError("GDN does not support packed sequence for now.")
+            assert batch == 1, "Packed sequence expects the batch dimension to be 1"
+            assert (
+                not self.config.deterministic_mode
+            ), "Packed sequence does not support deterministic mode."
+
+            # Prefer cu_seqlens_q_padded if available, otherwise use cu_seqlens_q
+            q_pad = packed_seq_params.cu_seqlens_q_padded
+            cu_seqlens_q = q_pad if q_pad is not None else packed_seq_params.cu_seqlens_q
+            # Prefer cu_seqlens_kv_padded if available, otherwise use cu_seqlens_kv
+            kv_pad = packed_seq_params.cu_seqlens_kv_padded
+            cu_seqlens_kv = kv_pad if kv_pad is not None else packed_seq_params.cu_seqlens_kv
+            assert torch.equal(cu_seqlens_q, cu_seqlens_kv), (
+                "Currently cu_seqlens_q must equal cu_seqlens_kv, "
+                f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
+            )
+            num_packed_seqs = cu_seqlens_q.shape[0] - 1
+            assert num_packed_seqs > 0, (
+                "Number of packed sequences must be greater than 0, "
+                f"but got {cu_seqlens_q=} and {cu_seqlens_kv=}"
+            )
+        else:
+            cu_seqlens_q = None
+            cu_seqlens_kv = None
 
         # Input projection
         nvtx_range_push(suffix="in_proj")
         qkvzba, _ = self.in_proj(hidden_states)
         nvtx_range_pop(suffix="in_proj")
 
         # CP All to All: CP to HP
-        qkvzba = tensor_a2a_cp2hp(
-            qkvzba,
-            seq_dim=0,
-            head_dim=-1,
-            cp_group=self.pg_collection.cp,
-            split_sections=[
-                self.qk_dim_local_tp,
-                self.qk_dim_local_tp,
-                self.v_dim_local_tp,
-                self.v_dim_local_tp,
-                self.num_value_heads // self.tp_size,
-                self.num_value_heads // self.tp_size,
-            ],
-        )
+        if packed_seq_params is not None:
+            unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens_q // self.cp_size, dim=0)
+            outputs = []
+            for qkvzba_i in unpacked_qkvzba:
+                qkvzba_i = tensor_a2a_cp2hp(
+                    qkvzba_i,
+                    seq_dim=0,
+                    head_dim=-1,
+                    cp_group=self.pg_collection.cp,
+                    split_sections=[
+                        self.qk_dim_local_tp,
+                        self.qk_dim_local_tp,
+                        self.v_dim_local_tp,
+                        self.v_dim_local_tp,
+                        self.num_value_heads // self.tp_size,
+                        self.num_value_heads // self.tp_size,
+                    ],
+                )
+                outputs.append(qkvzba_i)
+            qkvzba = torch.cat(outputs, dim=0)
+        else:
+            qkvzba = tensor_a2a_cp2hp(
+                qkvzba,
+                seq_dim=0,
+                head_dim=-1,
+                cp_group=self.pg_collection.cp,
+                split_sections=[
+                    self.qk_dim_local_tp,
+                    self.qk_dim_local_tp,
+                    self.v_dim_local_tp,
+                    self.v_dim_local_tp,
+                    self.num_value_heads // self.tp_size,
+                    self.num_value_heads // self.tp_size,
+                ],
+            )
 
         # Transpose: s b x --> b s x
         # From sbhd to bshd format
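
A minimal sketch of the offset arithmetic in the hunk above, assuming (as the split/concat logic implies) that each CP rank holds 1/cp_size of the tokens of every packed sub-sequence: before the CP-to-HP all-to-all the local shard is split at cu_seqlens_q // cp_size, and after the all-to-all each piece is full length again, which is why the HP-to-CP path further down splits at the unscaled cu_seqlens_q. The numbers below are illustrative only.

import torch

cp_size = 4
cu_seqlens_q = torch.tensor([0, 8, 24])      # two packed sequences, lengths 8 and 16
local_cu_seqlens = cu_seqlens_q // cp_size   # tensor([0, 2, 6]): local lengths 2 and 4

# CP->HP path splits the local shard at local_cu_seqlens; HP->CP splits at cu_seqlens_q.
assert local_cu_seqlens.tolist() == [0, 2, 6]
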
@@ -385,6 +427,7 @@ def forward(
             activation=self.activation,
             initial_state=None,
             output_final_state=False,
+            cu_seqlens=cu_seqlens_q,
         )
         nvtx_range_pop(suffix="conv1d")
 
@@ -440,6 +483,7 @@ def forward(
             initial_state=None,
             output_final_state=False,
             use_qk_l2norm_in_kernel=False,
+            cu_seqlens=cu_seqlens_q,
         )
         nvtx_range_pop(suffix="gated_delta_rule")
 
@@ -454,9 +498,19 @@ def forward(
         norm_out = norm_out.transpose(0, 1).contiguous()
 
         # CP all to all: HP to CP
-        norm_out = tensor_a2a_hp2cp(
-            norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
-        )
+        if packed_seq_params is not None:
+            unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens_q, dim=0)
+            outputs = []
+            for norm_out_i in unpacked_norm_out:
+                norm_out_i = tensor_a2a_hp2cp(
+                    norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
+                )
+                outputs.append(norm_out_i)
+            norm_out = torch.cat(outputs, dim=0)
+        else:
+            norm_out = tensor_a2a_hp2cp(
+                norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp
+            )
 
         # Output projection
         nvtx_range_push(suffix="out_proj")
@@ -575,6 +629,17 @@ def _backward_out_proj(self):
         self.out_proj.backward_dw()
 
 
+def _unpack_sequence(x, cu_seqlens, dim=1):
+    unpacked_x = []
+    num_seqs = cu_seqlens.shape[0] - 1
+    for i in range(num_seqs):
+        idx_start = cu_seqlens[i].item()
+        idx_end = cu_seqlens[i + 1].item()
+        chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
+        unpacked_x.append(x[tuple(chunked_index)])
+    return unpacked_x
+
+
 ####################
 # Sharded state dict utilities
 ####################
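
For reference, a minimal usage sketch of the _unpack_sequence helper added above, with made-up shapes and values: it slices a packed tensor along `dim` at the cumulative-sequence boundaries, and concatenating the chunks restores the packed layout.

import torch

x = torch.arange(24).reshape(6, 4)      # 6 packed tokens, hidden size 4
cu_seqlens = torch.tensor([0, 2, 6])    # two sub-sequences of lengths 2 and 4

chunks = _unpack_sequence(x, cu_seqlens, dim=0)
assert [c.shape[0] for c in chunks] == [2, 4]
assert torch.equal(torch.cat(chunks, dim=0), x)  # concatenation restores the packed tensor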