
Commit 590cee9

Support Llama3 (#8315)
* support llama-3
* Add llama-3 tokenizer
* fix for llama3
1 parent d4062e5 commit 590cee9
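
With this change, a Llama-3 checkpoint loads through the usual Auto classes, with AutoTokenizer resolving to the new Llama3Tokenizer. A minimal usage sketch, assuming a converted Llama-3 checkpoint is registered with PaddleNLP under the id used below:

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

# The checkpoint id is an assumption; substitute whatever converted
# Llama-3 weights are available to your PaddleNLP installation.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)  # resolves to Llama3Tokenizer
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype="bfloat16")

inputs = tokenizer("Hello, Llama 3!", return_tensors="pd")
print(type(tokenizer).__name__, inputs["input_ids"].shape)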

File tree

5 files changed: +312 −12 lines changed

llm/finetune_generation.py

Lines changed: 2 additions & 1 deletion
@@ -45,6 +45,7 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
+    Llama3Tokenizer,
     LlamaTokenizer,
 )
 from paddlenlp.utils.log import logger
@@ -232,7 +233,7 @@ def neft_post_hook(module, input, output):
     if tokenizer.chat_template is not None:
         data_args.eval_with_do_generation = False

-    if isinstance(tokenizer, LlamaTokenizer):
+    if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer):
         tokenizer.pad_token_id = tokenizer.eos_token_id

     if data_args.dataset_name_or_path is None:
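
The guard above matters because the Llama and Llama-3 tokenizers ship without a pad token, which batched fine-tuning needs. A minimal sketch of the same fallback outside the training script (checkpoint id assumed):

from paddlenlp.transformers import AutoTokenizer, Llama3Tokenizer, LlamaTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # assumed checkpoint id

# Reuse EOS as the padding token so variable-length examples can be collated.
if isinstance(tokenizer, (LlamaTokenizer, Llama3Tokenizer)):
    tokenizer.pad_token_id = tokenizer.eos_token_id

batch = tokenizer(["short prompt", "a somewhat longer prompt"], padding=True, return_tensors="pd")
print(batch["input_ids"].shape)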

paddlenlp/transformers/auto/tokenizer.py

Lines changed: 14 additions & 7 deletions
@@ -189,13 +189,20 @@ def _get_tokenizer_class_from_config(cls, pretrained_model_name_or_path, config_
         init_class = init_kwargs.pop("tokenizer_class", None)

         if init_class:
-            class_name = cls._name_mapping[init_class]
-            import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
-            tokenizer_class = getattr(import_class, init_class)
-            if use_fast:
-                fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
-                tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
-            return tokenizer_class
+            if init_class in cls._name_mapping:
+                class_name = cls._name_mapping[init_class]
+                import_class = import_module(f"paddlenlp.transformers.{class_name}.tokenizer")
+                tokenizer_class = getattr(import_class, init_class)
+                if use_fast:
+                    fast_tokenizer_class = cls._get_fast_tokenizer_class(init_class, class_name)
+                    tokenizer_class = fast_tokenizer_class if fast_tokenizer_class else tokenizer_class
+                return tokenizer_class
+            else:
+                import_class = import_module("paddlenlp.transformers")
+                tokenizer_class = getattr(import_class, init_class, None)
+                assert tokenizer_class is not None, f"Can't find tokenizer {init_class}"
+                return tokenizer_class
         # If no `init_class`, we use pattern recognition to recognize the tokenizer class.
         else:
             # TODO: Potential issue https://github.com/PaddlePaddle/PaddleNLP/pull/3786#discussion_r1024689810
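
The effect of the change above: when tokenizer_config.json names a class that is not in the internal name mapping (as Llama3Tokenizer is not), the class is now looked up directly on the paddlenlp.transformers package instead of raising a KeyError. A standalone sketch of that resolution pattern, using an illustrative mapping rather than PaddleNLP's real _name_mapping table:

from importlib import import_module

# Illustrative subset; the real table lives in AutoTokenizer._name_mapping.
_NAME_MAPPING = {"LlamaTokenizer": "llama"}

def resolve_tokenizer_class(init_class: str):
    """Mirror the lookup above: mapped classes come from their submodule,
    everything else falls back to the package-level export."""
    if init_class in _NAME_MAPPING:
        module = import_module(f"paddlenlp.transformers.{_NAME_MAPPING[init_class]}.tokenizer")
        return getattr(module, init_class)
    module = import_module("paddlenlp.transformers")
    tokenizer_class = getattr(module, init_class, None)
    assert tokenizer_class is not None, f"Can't find tokenizer {init_class}"
    return tokenizer_class

print(resolve_tokenizer_class("Llama3Tokenizer"))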

paddlenlp/transformers/llama/configuration.py

Lines changed: 2 additions & 0 deletions
@@ -147,6 +147,7 @@ def __init__(
         num_key_value_heads=None,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
+        rope_theta=10000.0,
         use_cache=True,
         use_recompute=False,
         recompute_granularity="full",
@@ -188,6 +189,7 @@ def __init__(

         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
+        self.rope_theta = rope_theta

         self.use_cache = use_cache
         self.use_recompute = use_recompute
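
rope_theta is the base of the rotary position embedding; 10000.0 preserves the old default, while Llama-3 configs use a much larger base (500000.0 upstream). A small standalone sketch of how the base shapes the RoPE inverse frequencies (not PaddleNLP's implementation):

import numpy as np

def rope_inv_freq(head_dim: int, base: float) -> np.ndarray:
    """Inverse frequencies used by rotary position embeddings."""
    return 1.0 / (base ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim))

# A larger base stretches the rotation wavelengths, i.e. positions drift apart
# more slowly in phase, which suits longer contexts.
print(rope_inv_freq(128, base=10000.0)[-1])   # old default
print(rope_inv_freq(128, base=500000.0)[-1])  # Llama-3-style base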

paddlenlp/transformers/llama/modeling.py

Lines changed: 6 additions & 3 deletions
@@ -813,24 +813,28 @@ def _init_rope(self):
             self.rotary_emb = LlamaRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "linear":
             self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "ntk":
             self.rotary_emb = LlamaNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "dynamic_ntk":
             self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
+                base=self.config.rope_theta,
             )
         else:
             raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}")
@@ -903,6 +907,7 @@ def forward(
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)
+
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
@@ -1027,7 +1032,6 @@ def forward(
             value_states = paddle.concat([past_key_value[1], value_states], axis=1)

         past_key_value = (key_states, value_states) if use_cache else None
-
         if self.kv_indices is not None:
             key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
             value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
@@ -1036,7 +1040,7 @@ def forward(
         # repeat k/v heads if n_kv_heads < n_heads
         # paddle version > 2.6 or develop support flash-attn with gqa/mqa
         paddle_version = float(paddle.__version__[:3])
-        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+        if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)):
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

@@ -1560,7 +1564,6 @@ def forward(
         else:
             attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
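
The use_flash_attention condition above exists because, per the in-code comment, only Paddle releases newer than 2.6 handle GQA/MQA inside the flash-attention kernel; in the non-flash path the key/value heads still have to be expanded to match the query heads. A sketch of what a repeat_kv-style expansion does, assuming the [batch, seq_len, num_kv_heads, head_dim] layout implied by the axis arguments above (not the exact PaddleNLP helper):

import paddle

def repeat_kv_sketch(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    """Duplicate each key/value head n_rep times so GQA key/value tensors
    line up with the query head count."""
    if n_rep == 1:
        return hidden_states
    batch, seq_len, num_kv_heads, head_dim = hidden_states.shape
    # Add a repeat axis next to the kv-head axis, tile it, then fold it back in.
    expanded = hidden_states.unsqueeze(3).tile([1, 1, 1, n_rep, 1])
    return expanded.reshape([batch, seq_len, num_kv_heads * n_rep, head_dim])

kv = paddle.randn([2, 16, 8, 128])       # 8 kv heads
print(repeat_kv_sketch(kv, 4).shape)     # -> [2, 16, 32, 128] for 32 query heads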
