 import numpy as np
 import paddle
+import paddle.distributed.fleet.meta_parallel as mpu
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet
@@ -296,19 +297,19 @@ def __init__(self, config):
         self.intermediate_size = config.intermediate_size

         if config.tensor_parallel_degree > 1:
-            self.gate_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.gate_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.intermediate_size,
                 gather_output=False,
                 has_bias=False,
             )
-            self.down_proj = fleet.meta_parallel.RowParallelLinear(
+            self.down_proj = mpu.RowParallelLinear(
                 self.intermediate_size,
                 self.hidden_size,
                 input_is_parallel=True,
                 has_bias=False,
             )
-            self.up_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.up_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.intermediate_size,
                 gather_output=False,
@@ -339,19 +340,19 @@ def __init__(self, config):
         self.num_heads = self.num_heads // config.tensor_parallel_degree

         if config.tensor_parallel_degree > 1:
-            self.q_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.q_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.hidden_size,
                 has_bias=False,
                 gather_output=False,
             )
-            self.k_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.k_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.hidden_size,
                 has_bias=False,
                 gather_output=False,
             )
-            self.v_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.v_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.hidden_size,
                 has_bias=False,
@@ -375,7 +376,7 @@ def __init__(self, config):
             )

         if config.tensor_parallel_degree > 1:
-            self.o_proj = fleet.meta_parallel.RowParallelLinear(
+            self.o_proj = mpu.RowParallelLinear(
                 self.hidden_size,
                 self.hidden_size,
                 has_bias=False,
@@ -581,7 +582,17 @@ def get_tensor_parallel_split_mappings(num_layers):
     def _init_weights(self, layer):
         """Initialization hook"""
-        if isinstance(layer, (nn.Linear, nn.Embedding)):
+        if isinstance(
+            layer,
+            (
+                nn.Linear,
+                nn.Embedding,
+                mpu.VocabParallelEmbedding,
+                mpu.ColumnParallelLinear,
+                mpu.RowParallelLinear,
+                LlamaLMHead,
+            ),
+        ):
             # In the dygraph mode, use the `set_value` to reset the parameter directly,
             # and reset the `state_dict` to update parameter in static mode.
             if isinstance(layer.weight, paddle.Tensor):
@@ -594,6 +605,16 @@ def _init_weights(self, layer):
                         shape=layer.weight.shape,
                     )
                 )
+        # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530
+        # sublayer is init first
+        # scale RowParallelLinear weight
+        with paddle.no_grad():
+            if isinstance(layer, LlamaMLP):
+                factor = 1 / math.sqrt(2 * self.config.num_hidden_layers)
+                layer.down_proj.weight.scale_(factor)
+            if isinstance(layer, LlamaAttention):
+                factor = 1 / math.sqrt(2 * self.config.num_hidden_layers)
+                layer.o_proj.weight.scale_(factor)


 @register_base_model
@@ -610,7 +631,7 @@ def __init__(self, config: LlamaConfig):
         self.hidden_size = config.hidden_size

         if config.tensor_parallel_degree > 1:
-            self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
+            self.embed_tokens = mpu.VocabParallelEmbedding(
                 self.vocab_size,
                 self.hidden_size,
                 weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
@@ -800,7 +821,7 @@ def __init__(self, config):
         self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output

         if self.enable_parallel_cross_entropy:  # and False: # and lm_head is distributed
-            self.loss_func = fleet.meta_parallel.ParallelCrossEntropy(ignore_index=self.ignore_index)
+            self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index)
         else:
             self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index)
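For context on the new scaling step in `_init_weights`: after the usual normal initialization, the output projection of each residual branch (`down_proj` in LlamaMLP, `o_proj` in LlamaAttention) is shrunk by 1/sqrt(2 * num_hidden_layers), the GPT-2 style residual scaling that keeps activations on the residual stream from growing with depth. Below is a minimal standalone sketch of that step; it is not part of the diff, and the layer count and tensor shape are placeholders chosen only for illustration.

import math

import paddle

# Hypothetical config value for illustration only.
num_hidden_layers = 32
factor = 1 / math.sqrt(2 * num_hidden_layers)

# Stand-in for a residual output-projection weight (e.g. o_proj / down_proj),
# initialized the same way the hook does (normal, std=0.02).
weight = paddle.normal(mean=0.0, std=0.02, shape=[64, 64])

with paddle.no_grad():
    weight.scale_(factor)  # in-place, mirroring layer.o_proj.weight.scale_(factor)

print(float(weight.std()))  # roughly 0.02 / sqrt(2 * 32)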