@@ -254,7 +254,7 @@ def forward(self, args):
         emb = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype)
         if position_ids is None and not self.config.fuse_rope:
             position_ids = (
-                paddle.range(
+                paddle.arange(
                     0,
                     input_ids.shape[1],
                     dtype="int64",
@@ -410,13 +410,13 @@ def forward(self, args):
         max_seq_len = hidden_states.shape[0] * self.config.tensor_parallel_degree
         if attention_mask is None:
             tgt_mask = None
-            attn_mask_start_row_indices = None
+            attn_mask_startend_row_indices = None
         elif attention_mask.dtype == paddle.int32:
             tgt_mask = None
-            attn_mask_start_row_indices = attention_mask[:, :, :max_seq_len]
+            attn_mask_startend_row_indices = attention_mask[:, :, :max_seq_len]
         else:
             tgt_mask = attention_mask[:, :, :max_seq_len, :max_seq_len]
-            attn_mask_start_row_indices = None
+            attn_mask_startend_row_indices = None
             assert len(tgt_mask.shape) == 4, f"Attention mask should be 4D tensor, but got {tgt_mask.shape}."

         position_ids_decoder = None
@@ -436,7 +436,7 @@ def forward(self, args):
                 self,
                 hidden_states,
                 attention_mask=tgt_mask,
-                attn_mask_start_row_indices=attn_mask_start_row_indices,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 position_ids=position_ids_decoder,
                 position_embeddings=tuple_position_embeddings,
                 use_reentrant=self.config.recompute_use_reentrant,
@@ -446,7 +446,7 @@ def forward(self, args):
                 self,
                 hidden_states=hidden_states,
                 attention_mask=tgt_mask,
-                attn_mask_start_row_indices=attn_mask_start_row_indices,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 position_ids=position_ids_decoder,
                 position_embeddings=tuple_position_embeddings,
             )
@@ -492,36 +492,44 @@ def forward(self, logits, labels):


 class GeneralModelForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
+    _decoder_layer_cls = None
+    _get_tensor_parallel_mappings = None
+    _init_weights = None
+    _keep_in_fp32_modules = None
     _tied_weights_keys = ["lm_head.weight"]
+    config_class = PretrainedConfig
+    transpose_weight_keys = None

-    def __init__(self, config: PretrainedConfig, decoder_layer, **kwargs):
+    def __init__(self, config: PretrainedConfig, **kwargs):
         # dynamic inherit DecoderLayer
-        DecoderLayerPipe = make_decoder_layer_pipe(decoder_layer)
+        if self._decoder_layer_cls is None:
+            raise ValueError("_decoder_layer_cls must be set before init.")
+        DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+
         new_initializer_range = math.sqrt(0.3333 / config.hidden_size)
         logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}")
         config.initializer_range = new_initializer_range

-        if config.get("moe_group", "") == "mp":
+        moe_group = config.get("moe_group", "dummy")
+        if moe_group == "mp":
             assert config.sequence_parallel

-        if config.moe_group in {"mp", "model", "tp", "mpdp"}:
+        if moe_group in {"mp", "model", "tp", "mpdp"}:
             assert config.sequence_parallel
-            logger.info(f"disable FFN tensor model parallel, moe-group={config.moe_group}")
+            logger.info(f"disable FFN tensor model parallel, moe-group={moe_group}")
             config.disable_ffn_model_parallel = True

-        config.moe_group_origin = config.moe_group
-        config.moe_group = _parse_moe_group(config.moe_group)
+        config.moe_group_origin = moe_group
+        config.moe_group = _parse_moe_group(moe_group)
         config.moe_world_size = dist.get_world_size(config.moe_group)
         if config.moe_world_size < 0:
             config.moe_world_size = 1
         config.moe_rank = dist.get_rank(config.moe_group)

         self.config = config
-
         hcg = get_hcg()
         tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1)
         tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0)
-
         config.tensor_parallel_degree = tensor_parallel_degree
         config.tensor_parallel_rank = tensor_parallel_rank

@@ -607,7 +615,7 @@ def __init__(self, config: PretrainedConfig, decoder_layer, **kwargs):
         )

     def get_loss_fn(self, config):
-        if config.dpo_config is not None:
+        if config.get("dpo_config", None) is not None:
             loss_fn = CriterionLayerPipe(config, use_infohub=True)
         else:
             loss_fn = CriterionLayerPipe(config)
@@ -633,7 +641,7 @@ def register_cls_attr(cls, config_class=None, pretrained_model_class=None):
     def _prepare_pipeline_inputs_func(cls, inputs):
         first_stage_keys = [
             "input_ids",
-            "attn_mask_start_row_indices",
+            "attn_mask_startend_row_indices",
             "position_ids",
             "nbatch_pack_offset",
         ]
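
With decoder_layer removed from __init__ in favor of the _decoder_layer_cls class attribute, a concrete pipeline model would presumably hook in by subclassing and setting the class-level attributes rather than passing the layer at construction time. A minimal sketch under that assumption; MyConfig and MyDecoderLayer are hypothetical placeholders, not part of this diff:

# Minimal sketch (assumption): a subclass supplies its config and decoder layer
# via class attributes instead of a decoder_layer constructor argument.
# MyConfig and MyDecoderLayer are hypothetical placeholders.
class MyModelForCausalLMPipe(GeneralModelForCausalLMPipe):
    config_class = MyConfig
    _decoder_layer_cls = MyDecoderLayer

# Construction no longer takes a decoder_layer argument:
model = MyModelForCausalLMPipe(MyConfig())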