Commit ecb2f66

fix llama pretrain init. (#6116)
1 parent a7a7251 commit ecb2f66

File tree

5 files changed: +49, -22 lines

examples/language_model/llama/modeling_pp.py

Lines changed: 2 additions & 0 deletions

@@ -199,6 +199,7 @@ class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     config_class = LlamaConfig

     _get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings
+    _init_weights = LlamaPretrainedModel._init_weights

     # NO base_model_prefix !!!!

@@ -258,5 +259,6 @@ def __init__(
             },
             num_virtual_pipeline_stages=virtual_pp_degree,
         )
+        self.apply(self._init_weights)
         # DON'T init PipelinePretrainedModel
         # PipelinePretrainedModel.__init__(self.super(), config=config)
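Note: LlamaForCausalLMPipe deliberately skips PipelinePretrainedModel.__init__, so the weight-init pass that would normally run there never happens; the commit re-binds _init_weights from LlamaPretrainedModel and calls self.apply(self._init_weights) by hand. A minimal standalone sketch of that pattern (toy layer names, not the PaddleNLP classes):

import paddle
import paddle.nn as nn


class TinyBlock(nn.Layer):
    def __init__(self, hidden=8):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)
        self.embed = nn.Embedding(16, hidden)


class TinyModel(nn.Layer):
    def __init__(self):
        super().__init__()
        self.block = TinyBlock()
        # Mirrors the commit: the hook must be applied explicitly because the usual
        # PretrainedModel.__init__ path is intentionally skipped here.
        self.apply(self._init_weights)

    def _init_weights(self, layer):
        # Layer.apply calls this once per sublayer (depth-first), then on self.
        if isinstance(layer, (nn.Linear, nn.Embedding)):
            layer.weight.set_value(
                paddle.tensor.normal(mean=0.0, std=0.02, shape=layer.weight.shape)
            )


model = TinyModel()
print(float(model.block.linear.weight.std()))  # close to 0.02 after the re-init

Because Layer.apply walks every sublayer and then the layer itself, a single call after construction is enough to re-initialize all parameters of the pipeline model.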

examples/language_model/llama/run_pretrain.py

Lines changed: 6 additions & 7 deletions

@@ -111,19 +111,18 @@ class ModelArguments:
         default="llama", metadata={"help": "Only support for llama pre-training for now."}
     )
     model_name_or_path: str = field(
-        default="gpt2-meidum-en",
+        default="facebook/tiny-random-llama",
         metadata={
             "help": "Path to pretrained model or model identifier from https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html"
         },
     )
-    hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."})
-    attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention probs dropout prob."})
+    tokenizer_name_or_path: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
-    tokenizer_name_or_path: Optional[str] = field(
-        default="gpt2-meidum-en", metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
-    )
     use_flash_attention: bool = field(
         default=False,
         metadata={"help": "use_flash_attention"},

@@ -202,7 +201,7 @@ def print_dataset(data, mode="train"):
     def build_dataset(index, name):
         dataset = GPTDataset(
             file_prefix=input_prefix,
-            build_data_file=training_args.local_rank == 0,
+            build_data_file=training_args.local_process_index == 0,
             micro_batch_size=training_args.per_device_train_batch_size
             if name == "train"
             else training_args.per_device_eval_batch_size,
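Note: this hunk replaces the leftover GPT-2 defaults with LLaMA-appropriate ones and lets tokenizer_name_or_path default to None. A minimal sketch of the dataclass-field pattern these arguments use; the fallback to model_name_or_path is an illustrative assumption about how a caller would consume the None default, not code from this commit:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default="facebook/tiny-random-llama",
        metadata={"help": "Path to pretrained model or model identifier"},
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )


args = ModelArguments()
tokenizer_path = args.tokenizer_name_or_path or args.model_name_or_path
print(tokenizer_path)  # falls back to the model path when no tokenizer path is given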

model_zoo/gpt/run_pretrain_trainer.py

Lines changed: 1 addition & 1 deletion

@@ -195,7 +195,7 @@ def print_dataset(data, mode="train"):
     def build_dataset(index, name):
        dataset = GPTDataset(
            file_prefix=input_prefix,
-            build_data_file=training_args.local_rank == 0,
+            build_data_file=training_args.local_process_index == 0,
            micro_batch_size=training_args.per_device_train_batch_size
            if name == "train"
            else training_args.per_device_eval_batch_size,
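Note: same local_rank -> local_process_index switch as above, so the data-file build is gated on the trainer's local process index rather than the raw launcher rank. A sketch of the general pattern of letting a single local process build a shared on-disk artifact; the helper names and barrier placement are illustrative assumptions, not the PaddleNLP code:

import os

import paddle.distributed as dist


def build_index_file(path):
    # Stand-in for the expensive preprocessing step (e.g. the GPTDataset index build).
    with open(path, "w") as f:
        f.write("index\n")


def ensure_index(path, local_process_index):
    # Only the first process on each machine writes the shared file ...
    if local_process_index == 0 and not os.path.exists(path):
        build_index_file(path)
    # ... and every rank waits until it exists before reading it.
    if dist.get_world_size() > 1:
        dist.barrier()


ensure_index("dataset.idx", local_process_index=0)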

paddlenlp/trainer/trainer.py

Lines changed: 9 additions & 4 deletions

@@ -115,6 +115,11 @@
 except:
     mix_precision_utils = None

+try:
+    from paddle.io.dataloader.dataloader_iter import _DataLoaderIterBase
+except:
+    from paddle.fluid.dataloader.dataloader_iter import _DataLoaderIterBase
+

 def paddlenlp_load(path, return_numpy=False):
     if return_numpy:

@@ -752,9 +757,9 @@ def train(
                 for p in model._layers.parameters():
                     if hasattr(p, "main_grad") and p.main_grad is not None:
                         assert p.grad is None
-                        p.main_grad = p.main_grad.scale(1.0 / model.accumulate_steps)
+                        p.main_grad = p.main_grad.scale(1.0 / self.args.gradient_accumulation_steps)
                     elif p.grad is not None:
-                        p.grad = p.grad.scale(1.0 / model.accumulate_steps)
+                        p.grad = p.grad.scale(1.0 / self.args.gradient_accumulation_steps)

                 # Optimizer step
                 optimizer_was_run = True

@@ -1930,7 +1935,7 @@ def evaluation_loop(

         if isinstance(dataloader, paddle.io.DataLoader):
             batch_size = dataloader.batch_sampler.batch_size
-        elif isinstance(dataloader, paddle.fluid.dataloader.dataloader_iter._DataLoaderIterBase):
+        elif isinstance(dataloader, _DataLoaderIterBase):
             # support for inner dataloader
             batch_size = dataloader._batch_sampler.batch_size
             # alias for inner dataloader

@@ -1942,7 +1947,7 @@ def evaluation_loop(
         if max_eval_iters > 0:
             # on eval limit steps
             num_samples = batch_size * self.args.dataset_world_size * max_eval_iters
-            if isinstance(dataloader, paddle.fluid.dataloader.dataloader_iter._DataLoaderIterBase) and isinstance(
+            if isinstance(dataloader, _DataLoaderIterBase) and isinstance(
                 dataloader._batch_sampler, NlpDistributedBatchSampler
             ):
                 consumed_samples = (
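Note: the accumulation hunk divides the summed gradients by self.args.gradient_accumulation_steps, the Trainer's own setting, instead of model.accumulate_steps, which is an attribute of the pipeline-parallel wrapper. A standalone sketch of why the 1/K factor is needed at all; it pre-scales the loss, which is mathematically equivalent to rescaling the accumulated gradients as the Trainer does, and it uses a plain Paddle layer rather than the Trainer:

import paddle
import paddle.nn as nn

accumulation_steps = 4  # plays the role of self.args.gradient_accumulation_steps
model = nn.Linear(8, 1)
optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())

for _ in range(accumulation_steps):
    x = paddle.randn([2, 8])
    # Pre-scaling the loss by 1/K means the gradients, which Paddle sums across
    # backward() calls until clear_grad(), end up as the mean over the K micro-batches.
    loss = model(x).mean() / accumulation_steps
    loss.backward()

optimizer.step()
optimizer.clear_grad()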

paddlenlp/transformers/llama/modeling.py

Lines changed: 31 additions & 10 deletions

@@ -21,6 +21,7 @@

 import numpy as np
 import paddle
+import paddle.distributed.fleet.meta_parallel as mpu
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet

@@ -296,19 +297,19 @@ def __init__(self, config):
         self.intermediate_size = config.intermediate_size

         if config.tensor_parallel_degree > 1:
-            self.gate_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.gate_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.intermediate_size,
                 gather_output=False,
                 has_bias=False,
             )
-            self.down_proj = fleet.meta_parallel.RowParallelLinear(
+            self.down_proj = mpu.RowParallelLinear(
                 self.intermediate_size,
                 self.hidden_size,
                 input_is_parallel=True,
                 has_bias=False,
             )
-            self.up_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.up_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.intermediate_size,
                 gather_output=False,

@@ -339,19 +340,19 @@ def __init__(self, config):
             self.num_heads = self.num_heads // config.tensor_parallel_degree

         if config.tensor_parallel_degree > 1:
-            self.q_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.q_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.hidden_size,
                 has_bias=False,
                 gather_output=False,
             )
-            self.k_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.k_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.hidden_size,
                 has_bias=False,
                 gather_output=False,
             )
-            self.v_proj = fleet.meta_parallel.ColumnParallelLinear(
+            self.v_proj = mpu.ColumnParallelLinear(
                 self.hidden_size,
                 self.hidden_size,
                 has_bias=False,

@@ -375,7 +376,7 @@ def __init__(self, config):
             )

         if config.tensor_parallel_degree > 1:
-            self.o_proj = fleet.meta_parallel.RowParallelLinear(
+            self.o_proj = mpu.RowParallelLinear(
                 self.hidden_size,
                 self.hidden_size,
                 has_bias=False,

@@ -581,7 +582,17 @@ def get_tensor_parallel_split_mappings(num_layers):

     def _init_weights(self, layer):
         """Initialization hook"""
-        if isinstance(layer, (nn.Linear, nn.Embedding)):
+        if isinstance(
+            layer,
+            (
+                nn.Linear,
+                nn.Embedding,
+                mpu.VocabParallelEmbedding,
+                mpu.ColumnParallelLinear,
+                mpu.RowParallelLinear,
+                LlamaLMHead,
+            ),
+        ):
             # In the dygraph mode, use the `set_value` to reset the parameter directly,
             # and reset the `state_dict` to update parameter in static mode.
             if isinstance(layer.weight, paddle.Tensor):

@@ -594,6 +605,16 @@ def _init_weights(self, layer):
                         shape=layer.weight.shape,
                     )
                 )
+        # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530
+        # sublayer is init first
+        # scale RowParallelLinear weight
+        with paddle.no_grad():
+            if isinstance(layer, LlamaMLP):
+                factor = 1 / math.sqrt(2 * self.config.num_hidden_layers)
+                layer.down_proj.weight.scale_(factor)
+            if isinstance(layer, LlamaAttention):
+                factor = 1 / math.sqrt(2 * self.config.num_hidden_layers)
+                layer.o_proj.weight.scale_(factor)


 @register_base_model

@@ -610,7 +631,7 @@ def __init__(self, config: LlamaConfig):
         self.hidden_size = config.hidden_size

         if config.tensor_parallel_degree > 1:
-            self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
+            self.embed_tokens = mpu.VocabParallelEmbedding(
                 self.vocab_size,
                 self.hidden_size,
                 weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),

@@ -800,7 +821,7 @@ def __init__(self, config):
         self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output

         if self.enable_parallel_cross_entropy:  # and False: # and lm_head is distributed
-            self.loss_func = fleet.meta_parallel.ParallelCrossEntropy(ignore_index=self.ignore_index)
+            self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index)
         else:
             self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index)
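Note: the expanded _init_weights now also matches the tensor-parallel layers (VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) and LlamaLMHead, and afterwards scales each block's output projections (down_proj, o_proj) by 1/sqrt(2 * num_hidden_layers). A small numeric sketch of that depth scaling; the layer count and init std below are illustrative values, not read from a real LlamaConfig:

import math

import paddle

num_hidden_layers = 32      # e.g. a LLaMA-7B-sized model; illustrative
initializer_range = 0.02    # typical transformer init std; illustrative

# One attention residual and one MLP residual per layer -> 2 * num_hidden_layers branches.
factor = 1 / math.sqrt(2 * num_hidden_layers)

w = paddle.tensor.normal(mean=0.0, std=initializer_range, shape=[1024, 1024])
with paddle.no_grad():
    w.scale_(factor)

print(factor)           # 0.125 for 32 layers
print(float(w.std()))   # roughly 0.02 * 0.125 = 0.0025

Shrinking the residual-branch output weights this way keeps the variance of the residual stream bounded as depth grows, which is why the comment in the hunk notes that sublayers are initialized first and then rescaled.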
