@@ -343,8 +343,19 @@ def __init__(self, *args, **kwargs):
 
     def forward(self, args):
         hidden_states, _, _, _, _ = parse_args(args)
-        hidden_states = super().forward(hidden_states)
-        return hidden_states
+
+        if self.config.num_nextn_predict_layers > 0:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+            hidden_states = hidden_states_list[0]
+            hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
+
+            output_list = [super().forward(hidden_states)]
+            for hidden_states in hidden_states_mtp:
+                output_list.append(super().forward(hidden_states))
+            return output_list
+        else:
+            hidden_states = super().forward(hidden_states)
+            return hidden_states
 
 
 class LayerNormPipe(LayerNorm):
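Note on the split above: the trunk is assumed to concatenate the main hidden states with one extra slice per multi-token-prediction (MTP) head along the last axis, so the norm layer splits it into num_nextn_predict_layers + 1 equal chunks. A minimal sketch of the shape arithmetic, with illustrative sizes that are not taken from this PR:

    # Hypothetical shapes: batch=2, seq=8, hidden=64, num_nextn_predict_layers=2,
    # so the incoming tensor is [2, 8, 64 * (2 + 1)] = [2, 8, 192].
    import paddle

    n_mtp = 2
    hidden = paddle.randn([2, 8, 64 * (n_mtp + 1)])

    # paddle.split with an int divides the axis into equal chunks
    chunks = paddle.split(hidden, n_mtp + 1, axis=-1)
    main = chunks[0]          # [2, 8, 64], the ordinary LM stream
    mtp = chunks[-n_mtp:]     # two [2, 8, 64] MTP streams
    assert main.shape == [2, 8, 64] and len(mtp) == 2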
@@ -389,6 +400,12 @@ def forward(self, args):
                 [batch_size, sequence_length, vocab_size]
                 representing unnormalized log probabilities for each token
         """
+        if self.config.num_nextn_predict_layers > 0:
+            logits = []
+            for _hidden_states in args:
+                logits.append(super().forward(_hidden_states))
+            return logits
+
 
         hidden_states, _, _, _, _ = parse_args(args)
         logits = super().forward(hidden_states)
         return logits
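When MTP is active, args here is the list produced by the norm pipe above rather than the usual packed tuple, so the head projects each stream separately and hands downstream stages a list of logits. A hedged sketch of that list-in/list-out contract (function and variable names are invented for illustration):

    # Mimics the contract only; not the real pipeline classes.
    def project_streams(streams, lm_head):
        # streams == [main, mtp_0, mtp_1, ...] when MTP is enabled
        return [lm_head(h) for h in streams]

The criterion stage is therefore expected to accept a list of logits whenever num_nextn_predict_layers > 0.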
@@ -507,12 +524,25 @@ class GeneralModelForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     transpose_weight_keys = None
     _embed_cls = None
     _rotary_emb_cls = None
+    _mtp_layer_pipe_cls = None
+    _embedding_pipe_cls = None
+    _decoder_layer_pipe_cls = None
+    _criterion_pipe_cls = None
+    _lmhead_pipe_cls = None
 
     def __init__(self, config: PretrainedConfig, **kwargs):
         # dynamic inherit DecoderLayer
         if self._decoder_layer_cls is None:
             raise ValueError("_decoder_layer_cls must be set before init.")
-        DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+
+        EmbeddingPipeCls = self._embedding_pipe_cls if self._embedding_pipe_cls is not None else EmbeddingPipe
+
+        if self._decoder_layer_pipe_cls is None:
+            DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+        else:
+            DecoderLayerPipe = self._decoder_layer_pipe_cls
+
+        LMHeadPipeCls = self._lmhead_pipe_cls if self._lmhead_pipe_cls is not None else LMHeadPipe
 
         new_initializer_range = math.sqrt(0.3333 / config.hidden_size)
         logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}")
@@ -559,7 +589,7 @@ def __init__(self, config: PretrainedConfig, **kwargs):
         else:
             self.add_sequential_layer(
                 LayerDesc(
-                    EmbeddingPipe, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
+                    EmbeddingPipeCls, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
                 ),
                 "model",
             )
@@ -573,6 +603,11 @@ def __init__(self, config: PretrainedConfig, **kwargs):
                 ),
                 f"model.layers.{i}",
             )
+        for i in range(config.num_nextn_predict_layers):
+            self.add_sequential_layer(
+                LayerDesc(self._mtp_layer_pipe_cls, config=config, layer_idx=config.num_hidden_layers + i),
+                f"model.layers.{config.num_hidden_layers + i}",
+            )
         for i in range(config.add_tail_layers):
             self.add_sequential_layer(
                 LayerDesc(
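The MTP stages are registered directly after the ordinary decoder layers, so their state-dict prefixes continue the model.layers numbering and line up with non-pipeline checkpoints. For example (DeepSeek-V3-like figures, used here purely as an illustration):

    # num_hidden_layers=61, num_nextn_predict_layers=1
    # model.layers.0 ... model.layers.60  -> ordinary decoder layers
    # model.layers.61                     -> MTP layer (layer_idx=61)

Note that this loop implicitly assumes _mtp_layer_pipe_cls has been set whenever num_nextn_predict_layers > 0; the base class leaves it as None.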
@@ -590,14 +625,14 @@ def __init__(self, config: PretrainedConfig, **kwargs):
             self.add_sequential_layer(
                 SharedLayerDesc(
                     "model_shared_weight",
-                    LMHeadPipe,
+                    LMHeadPipeCls,
                     shared_weight_attr="embedding_weight",
                     config=config,
                 ),
                 "lm_head",
             )
         else:
-            self.add_sequential_layer(LayerDesc(LMHeadPipe, config=config), "lm_head")
+            self.add_sequential_layer(LayerDesc(LMHeadPipeCls, config=config), "lm_head")
         recompute_interval = 0
 
         seg_method = config.pp_seg_method if hasattr(config, "pp_seg_method") else "layer:DecoderLayer|EmptyLayer"
@@ -630,10 +665,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
         )
 
     def get_loss_fn(self, config):
+        CriterionPipeCls = self._criterion_pipe_cls if self._criterion_pipe_cls is not None else CriterionLayerPipe
+
         if config.get("dpo_config", None) is not None:
-            loss_fn = CriterionLayerPipe(config, use_infohub=True)
+            loss_fn = CriterionPipeCls(config, use_infohub=True)
         else:
-            loss_fn = CriterionLayerPipe(config)
+            loss_fn = CriterionPipeCls(config)
 
         return loss_fn
 
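Because the LM head may now return a list, a custom `_criterion_pipe_cls` would typically fold the MTP streams into the final loss. One plausible shape, hedged: the weighting scheme below is an assumption, not something this PR specifies:

    # Hypothetical MTP-aware loss; base_loss stands in for the existing
    # single-stream criterion and mtp_lambda is an invented weight.
    def mtp_aware_loss(logits_list, labels, base_loss, mtp_lambda=0.3):
        main = base_loss(logits_list[0], labels)
        # Simplification: real MTP shifts labels one extra step per head.
        extra = [base_loss(l, labels) for l in logits_list[1:]]
        return main + mtp_lambda * sum(extra) / max(len(extra), 1)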