
Commit 77680b6

Merge commit 'refs/pull/6670/head' of github.com:PaddlePaddle/PaddleNLP into llama

2 parents: dc32055 + 951ab83

File tree: 6 files changed, +170 −31 lines


paddlenlp/trainer/trainer.py

Lines changed: 17 additions & 1 deletion
@@ -156,7 +156,7 @@
 try:
     from paddle.io.dataloader.dataloader_iter import _DataLoaderIterBase
 except:
-    from paddle.fluid.dataloader.dataloader_iter import _DataLoaderIterBase
+    from paddle.base.dataloader.dataloader_iter import _DataLoaderIterBase


 def is_dp_group_support_in_group_sharded_parallel():
@@ -689,6 +689,22 @@ def train(
             # so, the trainable numel is a little bigger than real.
             logger.info(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)")

+        model.llama = paddle.jit.to_static(
+            model.llama,
+            input_spec=[
+                paddle.static.InputSpec(name="input_ids", shape=[-1, -1], dtype="int64"),  # input_ids
+                None,  # position_ids
+                None,  # attention_mask
+                None,  # inputs_embeds
+                # paddle.static.InputSpec(name="labels", shape=[-1, -1], dtype="int64"),  # labels
+                False,  # use_cache
+                None,  # past_key_values
+                None,  # output_attentions
+                None,  # output_hidden_states
+                None,  # return_dict
+            ],
+        )
+        paddle.base.core._set_prim_forward_blacklist("expand_v2")
         start_time = time.time()
         self._globalstep_last_start_time = time.time()
         self.state.epoch = 0
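Context for the to_static call above: paddle.jit.to_static traces a dynamic-graph Layer into a static program, and each entry of input_spec matches a forward argument positionally, either as a paddle.static.InputSpec (with -1 for variable dimensions) or as a concrete Python value such as None or False. A minimal, self-contained sketch of the same pattern on a toy layer follows; the ToyEmbedder class and its shapes are illustrative only, not part of this commit.

import paddle
from paddle.static import InputSpec


class ToyEmbedder(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.embed = paddle.nn.Embedding(100, 16)

    def forward(self, input_ids, use_cache=False):
        # use_cache is pinned to a constant by to_static below.
        return self.embed(input_ids)


layer = ToyEmbedder()
static_layer = paddle.jit.to_static(
    layer,
    input_spec=[
        InputSpec(name="input_ids", shape=[-1, -1], dtype="int64"),  # dynamic batch/seq dims
        False,  # use_cache passed as a concrete value rather than an InputSpec
    ],
)
out = static_layer(paddle.randint(0, 100, shape=[2, 8], dtype="int64"))
print(out.shape)  # expected [2, 8, 16]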

paddlenlp/transformers/llama/configuration.py

Lines changed: 90 additions & 0 deletions
@@ -60,6 +60,92 @@
         "use_recompute": False,
         "use_flash_attention": False,
     },
+    "__internal_testing__/distributed-projection-llama-7b": {
+        "hidden_size": 2048,
+        "embedding_output_size": 4096,
+        "initializer_range": 0.02,
+        "intermediate_size": 5504,
+        "max_position_embeddings": 2048,
+        "model_type": "llama",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 2,
+        "rms_norm_eps": 1e-06,
+        "vocab_size": 32000,
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "pad_token_id": 0,
+        "use_cache": False,
+        "use_recompute": False,
+        "use_flash_attention": False,
+    },
+    "facebook/llama-7b": {
+        "hidden_size": 4096,
+        "initializer_range": 0.02,
+        "intermediate_size": 11008,
+        "max_position_embeddings": 2048,
+        "model_type": "llama",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 32,
+        "rms_norm_eps": 1e-06,
+        "vocab_size": 32000,
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "pad_token_id": 0,
+        "use_cache": False,
+        "use_recompute": False,
+        "use_flash_attention": False,
+    },
+    "facebook/llama-13b": {
+        "hidden_size": 5120,
+        "initializer_range": 0.02,
+        "intermediate_size": 13824,
+        "max_position_embeddings": 2048,
+        "model_type": "llama",
+        "num_attention_heads": 40,
+        "num_hidden_layers": 40,
+        "rms_norm_eps": 1e-06,
+        "vocab_size": 32000,
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "pad_token_id": 0,
+        "use_cache": False,
+        "use_recompute": False,
+        "use_flash_attention": False,
+    },
+    "facebook/llama-30b": {
+        "hidden_size": 6656,
+        "initializer_range": 0.02,
+        "intermediate_size": 17920,
+        "max_position_embeddings": 2048,
+        "model_type": "llama",
+        "num_attention_heads": 52,
+        "num_hidden_layers": 60,
+        "rms_norm_eps": 1e-06,
+        "vocab_size": 32000,
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "pad_token_id": 0,
+        "use_cache": False,
+        "use_recompute": False,
+        "use_flash_attention": False,
+    },
+    "facebook/llama-65b": {
+        "hidden_size": 8192,
+        "initializer_range": 0.02,
+        "intermediate_size": 22016,
+        "max_position_embeddings": 2048,
+        "model_type": "llama",
+        "num_attention_heads": 64,
+        "num_hidden_layers": 80,
+        "rms_norm_eps": 1e-05,
+        "vocab_size": 32000,
+        "bos_token_id": 1,
+        "eos_token_id": 2,
+        "pad_token_id": 0,
+        "use_cache": False,
+        "use_recompute": False,
+        "use_flash_attention": False,
+    },
 }

 # Hypothetical model weights (tiny-random-llama) for test only
@@ -168,10 +254,14 @@ def __init__(
         alibi=False,
         rope_scaling_factor=1.0,
         rope_scaling_type=None,
+        embedding_output_size=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
+        if embedding_output_size is None:
+            embedding_output_size = hidden_size
+        self.embedding_output_size = embedding_output_size
         self.intermediate_size = intermediate_size
         self.max_position_embeddings = max_position_embeddings
         self.seq_length = seq_length
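The only behavioural change in LlamaConfig is the new embedding_output_size option, which falls back to hidden_size when it is not supplied, so existing presets keep their old behaviour. A standalone sketch of that defaulting logic (ToyLlamaConfig is a stand-in class so the snippet runs without PaddleNLP):

# Toy stand-in for LlamaConfig's new defaulting logic, not the real class.
class ToyLlamaConfig:
    def __init__(self, hidden_size=4096, embedding_output_size=None, **kwargs):
        self.hidden_size = hidden_size
        if embedding_output_size is None:
            embedding_output_size = hidden_size
        self.embedding_output_size = embedding_output_size


# Existing configs: embedding_output_size defaults to hidden_size.
assert ToyLlamaConfig(hidden_size=4096).embedding_output_size == 4096
# The new __internal_testing__/distributed-projection-llama-7b preset sets both explicitly.
assert ToyLlamaConfig(hidden_size=2048, embedding_output_size=4096).embedding_output_size == 4096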

paddlenlp/transformers/llama/modeling.py

Lines changed: 34 additions & 30 deletions
@@ -299,8 +299,9 @@ class LlamaRMSNorm(nn.Layer):
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.embedding_output_size = config.embedding_output_size
         self.weight = paddle.create_parameter(
-            shape=[self.hidden_size],
+            shape=[self.embedding_output_size],
             dtype=paddle.get_default_dtype(),
             default_initializer=nn.initializer.Constant(1.0),
         )
@@ -465,6 +466,7 @@ class LlamaMLP(nn.Layer):
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.embedding_output_size = config.embedding_output_size
         self.intermediate_size = config.intermediate_size
         self.tensor_parallel_degree = config.tensor_parallel_degree
         self.fuse_attention_ffn = config.fuse_attention_ffn
@@ -479,39 +481,41 @@ def __init__(self, config):
         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
                 self.gate_up_fused_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.intermediate_size * 2,
                     gather_output=False,
                     has_bias=False,
                 )
             else:
                 self.gate_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.intermediate_size,
                     gather_output=False,
                     has_bias=False,
                 )
                 self.up_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.intermediate_size,
                     gather_output=False,
                     has_bias=False,
                 )

             self.down_proj = RowParallelLinear(
                 self.intermediate_size,
-                self.hidden_size,
+                self.embedding_output_size,
                 input_is_parallel=True,
                 has_bias=False,
             )
         else:
             if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = nn.Linear(
+                    self.embedding_output_size, self.intermediate_size * 2, bias_attr=False
+                )
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = nn.Linear(self.embedding_output_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = nn.Linear(self.embedding_output_size, self.intermediate_size, bias_attr=False)

-            self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+            self.down_proj = nn.Linear(self.intermediate_size, self.embedding_output_size, bias_attr=False)

     def forward(self, x):
         if self.fuse_attention_ffn:
@@ -530,6 +534,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

         self.config = config
         self.hidden_size = config.hidden_size
+        self.embedding_output_size = config.embedding_output_size
         self.num_heads = config.num_attention_heads

         self.head_dim = self.hidden_size // config.num_attention_heads
@@ -590,78 +595,78 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     3 * self.hidden_size,
                     has_bias=False,
                     gather_output=False,
                 )
             else:
                 self.q_proj = ColumnParallelLinear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.hidden_size,
                     has_bias=False,
                     gather_output=False,
                 )
                 if self.kv_indices is None:
                     self.k_proj = ColumnParallelLinear(
-                        self.hidden_size,
+                        self.embedding_output_size,
                         self.config.num_key_value_heads * self.head_dim,
                         has_bias=False,
                         gather_output=False,
                     )
                     self.v_proj = ColumnParallelLinear(
-                        self.hidden_size,
+                        self.embedding_output_size,
                         self.config.num_key_value_heads * self.head_dim,
                         has_bias=False,
                         gather_output=False,
                     )
                 else:
                     self.k_proj = nn.Linear(
-                        self.hidden_size,
+                        self.embedding_output_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )
                     self.v_proj = nn.Linear(
-                        self.hidden_size,
+                        self.embedding_output_size,
                         self.config.num_key_value_heads * self.head_dim,
                         bias_attr=False,
                     )

         else:
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     3 * self.hidden_size,
                     bias_attr=False,
                 )
             else:
                 self.q_proj = nn.Linear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
                 self.k_proj = nn.Linear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
                 self.v_proj = nn.Linear(
-                    self.hidden_size,
+                    self.embedding_output_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )

         if config.tensor_parallel_degree > 1:
             self.o_proj = RowParallelLinear(
                 self.hidden_size,
-                self.hidden_size,
+                self.embedding_output_size,
                 has_bias=False,
                 input_is_parallel=True,
             )
         else:
             self.o_proj = nn.Linear(
                 self.hidden_size,
-                self.hidden_size,
+                self.embedding_output_size,
                 bias_attr=False,
             )

@@ -1078,6 +1083,7 @@ def __init__(self, config: LlamaConfig):
         super().__init__(config)
         self.vocab_size = config.vocab_size
         self.hidden_size = config.hidden_size
+        self.embedding_output_size = config.embedding_output_size
         self.sequence_parallel = config.sequence_parallel
         self.recompute_granularity = config.recompute_granularity
         self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
@@ -1087,13 +1093,13 @@ def __init__(self, config: LlamaConfig):
         if config.tensor_parallel_degree > 1:
             self.embed_tokens = mpu.VocabParallelEmbedding(
                 self.vocab_size,
-                self.hidden_size,
+                self.embedding_output_size,
                 weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
             )
         else:
             self.embed_tokens = nn.Embedding(
                 self.vocab_size,
-                self.hidden_size,
+                self.embedding_output_size,
             )

         self.layers = nn.LayerList(
@@ -1115,12 +1121,10 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         if len(attention_mask.shape) == 2:
             expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
-            # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
-            if input_shape[-1] > 1:
-                combined_attention_mask = _make_causal_mask(
-                    input_shape, past_key_values_length=past_key_values_length
-                )
-                expanded_attn_mask = expanded_attn_mask & combined_attention_mask
+            # For the decoding phase in generation (seq_length = 1) no causal mask is needed; since this
+            # change targets pretraining, the `if input_shape[-1] > 1:` guard is temporarily removed.
+            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+            expanded_attn_mask = expanded_attn_mask & combined_attention_mask
         # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
         elif len(attention_mask.shape) == 3:
             expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
@@ -1359,7 +1363,7 @@ def __init__(self, config: LlamaConfig):
         vocab_size = config.vocab_size

         self.weight = self.create_parameter(
-            shape=[config.hidden_size, vocab_size],
+            shape=[config.embedding_output_size, vocab_size],
             dtype=paddle.get_default_dtype(),
         )
         # Must set distributed attr for Tensor Parallel !
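Net effect in modeling.py: the token embedding, the RMSNorm weight, the input side of the q/k/v and MLP projections, and the output side of o_proj and down_proj now use config.embedding_output_size, while the attention heads themselves keep operating at config.hidden_size. A rough shape-flow sketch with toy dimensions taken from the new testing preset (not the actual model classes):

import paddle
import paddle.nn as nn

# Toy dimensions only, matching the __internal_testing__/distributed-projection-llama-7b preset.
hidden_size, embedding_output_size, vocab_size = 2048, 4096, 32000

embed_tokens = nn.Embedding(vocab_size, embedding_output_size)           # emits embedding-width states
q_proj = nn.Linear(embedding_output_size, hidden_size, bias_attr=False)  # projects down to the attention width
o_proj = nn.Linear(hidden_size, embedding_output_size, bias_attr=False)  # projects back up after attention

input_ids = paddle.randint(0, vocab_size, shape=[2, 8], dtype="int64")
hidden_states = embed_tokens(input_ids)  # [2, 8, 4096]
q = q_proj(hidden_states)                # [2, 8, 2048]
out = o_proj(q)                          # [2, 8, 4096], back to the embedding width
print(hidden_states.shape, q.shape, out.shape)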

paddlenlp/transformers/llama/tokenizer.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@ class LlamaTokenizer(PretrainedTokenizer):
         "vocab_file": {
             "__internal_testing__/micro-random-llama": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
             "__internal_testing__/tiny-random-llama": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
+            "__internal_testing__/distributed-projection-llama-7b": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
             "facebook/llama-7b": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
             "facebook/llama-13b": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
             "facebook/llama-30b": "https://bj.bcebos.com/paddlenlp/models/transformers/llama/sentencepiece.bpe.model",
@@ -46,6 +47,7 @@ class LlamaTokenizer(PretrainedTokenizer):
     pretrained_init_configuration = {
         "__internal_testing__/micro-random-llama": {},
         "__internal_testing__/tiny-random-llama": {},
+        "__internal_testing__/distributed-projection-llama-7b": {},
         "facebook/llama-7b": {},
         "facebook/llama-13b": {},
         "facebook/llama-30b": {},

tests/test_tipc/benchmark/options.py

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,8 @@
     from .modules.stablediffusion import StableDiffusionBenchmark
 except Exception:
     StableDiffusionBenchmark = None
+from paddlenlp.trainer.argparser import strtobool
+
 from .modules.t5_for_conditional_generation import T5ForConditionalGenerationBenchmark
 from .modules.xlnet import XLNetBenchmark

@@ -156,6 +158,7 @@ def get_parser():
         help='The option of profiler, which should be in format "key1=value1;key2=value2;key3=value3".',
     )
     parser.add_argument("--save_model", type=str, default=None, help="Directory to save models. ")
+    parser.add_argument("--use_nsys", type=strtobool, default=False, help="Enable nsys.")

     return parser
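The new --use_nsys flag uses a strtobool-style argparse type so command-line values such as "true"/"false" are converted to booleans. A standalone sketch of that pattern (the strtobool below is a local stand-in so the snippet runs without paddlenlp.trainer.argparser):

import argparse


def strtobool(v):
    # Local stand-in for paddlenlp.trainer.argparser.strtobool, kept self-contained here.
    if v.lower() in ("y", "yes", "t", "true", "on", "1"):
        return True
    if v.lower() in ("n", "no", "f", "false", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Truthy value expected, got {v!r}.")


parser = argparse.ArgumentParser()
parser.add_argument("--use_nsys", type=strtobool, default=False, help="Enable nsys.")

print(parser.parse_args(["--use_nsys", "true"]).use_nsys)  # True
print(parser.parse_args([]).use_nsys)                      # False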
