@@ -508,12 +508,28 @@ class GeneralModelForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     _embed_cls = None
     _rotary_emb_cls = None
     _norm_cls = "rms_norm"
+    _mtp_layer_pipe_cls = None
+    _embedding_pipe_cls = None
+    _decoder_layer_pipe_cls = None
+    _criterion_pipe_cls = None
+    _lmhead_pipe_cls = None
+    _rms_norm_pipe_cls = None
 
     def __init__(self, config: PretrainedConfig, **kwargs):
         # dynamic inherit DecoderLayer
         if self._decoder_layer_cls is None:
             raise ValueError("_decoder_layer_cls must be set before init.")
-        DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+
+        EmbeddingPipeCls = self._embedding_pipe_cls if self._embedding_pipe_cls is not None else EmbeddingPipe
+
+        if self._decoder_layer_pipe_cls is None:
+            DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+        else:
+            DecoderLayerPipe = self._decoder_layer_pipe_cls
+
+        LMHeadPipeCls = self._lmhead_pipe_cls if self._lmhead_pipe_cls is not None else LMHeadPipe
+        MTPLayerPipeCls = self._mtp_layer_pipe_cls if self._mtp_layer_pipe_cls is not None else None
+        RMSNormPipeCls = self._rms_norm_pipe_cls if self._rms_norm_pipe_cls is not None else RMSNormPipe
 
         new_initializer_range = math.sqrt(0.3333 / config.hidden_size)
         logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}")
@@ -560,7 +576,7 @@ def __init__(self, config: PretrainedConfig, **kwargs):
         else:
             self.add_sequential_layer(
                 LayerDesc(
-                    EmbeddingPipe, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
+                    EmbeddingPipeCls, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
                 ),
                 "model",
             )
@@ -574,6 +590,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
                 ),
                 f"model.layers.{i}",
             )
+        for i in range(config.num_nextn_predict_layers):
+            if MTPLayerPipeCls is not None:
+                self.add_sequential_layer(
+                    LayerDesc(MTPLayerPipeCls, config=config, layer_idx=config.num_hidden_layers + i),
+                    f"model.layers.{config.num_hidden_layers + i}",
+                )
         for i in range(config.add_tail_layers):
             self.add_sequential_layer(
                 LayerDesc(
@@ -583,22 +605,22 @@ def __init__(self, config: PretrainedConfig, **kwargs):
             )
 
         self.add_sequential_layer(
-            LayerDesc(RMSNormPipe if self._norm_cls == "rms_norm" else LayerNormPipe, config=config),
+            LayerDesc(RMSNormPipeCls if self._norm_cls == "rms_norm" else LayerNormPipe, config=config),
             "model.norm",
         )
 
         if config.tie_word_embeddings:
             self.add_sequential_layer(
                 SharedLayerDesc(
                     "model_shared_weight",
-                    LMHeadPipe,
+                    LMHeadPipeCls,
                     shared_weight_attr="embedding_weight",
                     config=config,
                 ),
                 "lm_head",
             )
         else:
-            self.add_sequential_layer(LayerDesc(LMHeadPipe, config=config), "lm_head")
+            self.add_sequential_layer(LayerDesc(LMHeadPipeCls, config=config), "lm_head")
         recompute_interval = 0
 
         seg_method = config.pp_seg_method if hasattr(config, "pp_seg_method") else "layer:DecoderLayer|EmptyLayer"
@@ -631,10 +653,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
         )
 
     def get_loss_fn(self, config):
+        CriterionPipeCls = self._criterion_pipe_cls if self._criterion_pipe_cls is not None else CriterionLayerPipe
+
         if config.get("dpo_config", None) is not None:
-            loss_fn = CriterionLayerPipe(config, use_infohub=True)
+            loss_fn = CriterionPipeCls(config, use_infohub=True)
         else:
-            loss_fn = CriterionLayerPipe(config)
+            loss_fn = CriterionPipeCls(config)
 
         return loss_fn
 
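Usage note (a minimal sketch, not part of the commit): the new `_*_pipe_cls` hooks let a subclass swap in custom pipeline layers without overriding `__init__`. The subclass and layer names below (`MyModelForCausalLMPipe`, `MyDecoderLayer`, `MyDecoderLayerPipe`, `MyMTPLayerPipe`, `MyCriterionPipe`) are placeholders, and the sketch assumes they accept the same constructor arguments the defaults receive above (`config`, plus `layer_idx` for the MTP layer and `use_infohub` for the criterion).

    # Hypothetical subclass wiring the new hooks; any hook left as None
    # falls back to the default pipe class used before this change.
    class MyModelForCausalLMPipe(GeneralModelForCausalLMPipe):
        # Existing required hook: the eager decoder layer that
        # make_decoder_layer_pipe() wraps when no custom pipe class is given.
        _decoder_layer_cls = MyDecoderLayer

        # New optional hooks introduced by this change.
        _decoder_layer_pipe_cls = MyDecoderLayerPipe
        _criterion_pipe_cls = MyCriterionPipe
        # MTP layers are only appended when this hook is set and
        # config.num_nextn_predict_layers > 0.
        _mtp_layer_pipe_cls = MyMTPLayerPipe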