
Commit 0ec78aa

[NEW Model] Add jamba (#8517)

* Add jamba

1 parent 12107af

File tree

14 files changed: +3001 −14 lines


llm/run_finetune.py

Lines changed: 1 addition & 1 deletion

@@ -342,7 +342,7 @@ def neft_post_hook(module, input, output):
 
     if data_args.zero_padding:
         if (
-            model.base_model_prefix not in ["llama", "bloom", "chatglm", "chatglm_v2", "qwen", "mistral"]
+            model.base_model_prefix not in ["llama", "bloom", "chatglm", "chatglm_v2", "qwen", "mistral", "jamba"]
             and training_args.pipeline_parallel_degree < 1
         ):
             raise NotImplementedError(

llm/utils/data.py

Lines changed: 5 additions & 11 deletions

@@ -56,11 +56,12 @@ def get_convert_example(model):
         "qwen2_moe",
         "gpt",
         "yuan",
+        "jamba",
     ]:
         return convert_example_common
     else:
         raise ValueError(
-            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2, qwen2_moe, yuan",
+            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2, qwen2_moe, yuan, jamba",
         )
 
 
@@ -198,9 +199,7 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, zero_pad
         features["position_ids"] = list(range(seq_length))
     if zero_padding:
         if flash_mask:
-            features["attn_mask_startend_row_indices"] = (
-                [seq_length] * seq_length
-            )
+            features["attn_mask_startend_row_indices"] = [seq_length] * seq_length
         else:
             features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
 
@@ -236,13 +235,10 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, z
     features = {"input_ids": input_ids, "labels": labels}
     if zero_padding:
         if flash_mask:
-            features["attn_mask_startend_row_indices"] = (
-                [seq_length] * seq_length
-            )
+            features["attn_mask_startend_row_indices"] = [seq_length] * seq_length
         else:
             features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
 
-
     if "position_ids" in rounds_inputs:
         rounds_inputs["position_ids"] = rounds_inputs["position_ids"][:-1]
 
@@ -252,9 +248,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, z
 
 def convert_example_chatglm(example, tokenizer, data_args, is_test=True, zero_padding=False, flash_mask=False):
     if flash_mask:
-        raise ValueError(
-            "chatglm does not support flash mask for now!"
-        )
+        raise ValueError("chatglm does not support flash mask for now!")
     if tokenizer.chat_template is not None:
         # chatglm only support single-round finetune
         example = convert_multi_rounds_to_single_round(example, tokenizer)
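The one-line rewrites above are cosmetic; the value they assign is the interesting part. Below is a minimal conceptual sketch of how the compact `attn_mask_startend_row_indices` used on the flash_mask path relates to the dense `np.tri` mask used otherwise, assuming the convention that each entry marks the end row of the attendable span for its key column (the authoritative semantics live in the FlashMask attention kernels, not in this file):

```python
import numpy as np

seq_length = 6

# Compact FlashMask form produced when zero_padding and flash_mask are both enabled:
# one row index per key position; seq_length everywhere means "nothing masked beyond causality".
attn_mask_startend_row_indices = [seq_length] * seq_length

# Dense form produced on the non-flash path: a lower-triangular causal mask.
dense_causal = np.tri(seq_length, seq_length, dtype=bool)

# Expand the compact form back to a dense mask under the assumed convention that
# query rows col..end_row-1 may attend to key column col.
expanded = np.zeros((seq_length, seq_length), dtype=bool)
for col, end_row in enumerate(attn_mask_startend_row_indices):
    expanded[col:end_row, col] = True

assert (expanded == dense_causal).all()
```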

paddlenlp/trainer/trainer.py

Lines changed: 12 additions & 1 deletion

@@ -1827,7 +1827,18 @@ def _wrap_model(self, model, training=True):
         # Multi-gpu training
         if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
             # MOE use DDP to broadcaset parameters.
-            model = paddle.DataParallel(model)
+            ddp_kwargs = {}
+            if self.args.ddp_find_unused_parameters is not None:
+                ddp_kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
+            elif isinstance(model, PretrainedModel):
+                # find_unused_parameters breaks checkpointing as per
+                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+                ddp_kwargs["find_unused_parameters"] = not any(
+                    hasattr(m, "enable_recompute") and m.enable_recompute for m in model.sublayers(include_self=True)
+                )
+            else:
+                ddp_kwargs["find_unused_parameters"] = True
+            model = paddle.DataParallel(model, **ddp_kwargs)
             # Distributed training (should be after fp16 initialization)
 
         if self.args.amp_master_grad:
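The new branch resolves `find_unused_parameters` from three sources in order: the explicit `ddp_find_unused_parameters` argument, a recompute-based heuristic for `PretrainedModel` instances, and a permissive default. A standalone sketch of that precedence (the helper name and the `paddle.nn.Layer` stand-in check are ours, not part of the commit):

```python
import paddle


def resolve_find_unused_parameters(args, model):
    """Mirror the precedence used in _wrap_model: explicit flag > recompute heuristic > True."""
    if args.ddp_find_unused_parameters is not None:
        # 1. An explicit --ddp_find_unused_parameters always wins.
        return args.ddp_find_unused_parameters
    if isinstance(model, paddle.nn.Layer):  # stand-in for isinstance(model, PretrainedModel)
        # 2. Recompute re-runs the forward pass during backward, which conflicts with
        #    find_unused_parameters, so turn it off when any sublayer recomputes.
        return not any(
            hasattr(m, "enable_recompute") and m.enable_recompute
            for m in model.sublayers(include_self=True)
        )
    # 3. Otherwise keep the permissive default.
    return True


# model = paddle.DataParallel(model, find_unused_parameters=resolve_find_unused_parameters(args, model))
```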

paddlenlp/trainer/training_args.py

Lines changed: 12 additions & 0 deletions

@@ -343,6 +343,9 @@ class TrainingArguments:
             The list of integrations to report the results and logs to.
             Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`.
             `"none"` for no integrations.
+        ddp_find_unused_parameters (`bool`, *optional*):
+            When using distributed training, the value of the flag `find_unused_parameters` passed to
+            `paddle.DataParallel`. Will default to `False` if recompute is used, `True` otherwise.
         wandb_api_key (`str`, *optional*):
             Weights & Biases (WandB) API key(s) for authentication with the WandB service.
         resume_from_checkpoint (`str`, *optional*):

@@ -762,6 +765,15 @@ class TrainingArguments:
     report_to: Optional[List[str]] = field(
         default=None, metadata={"help": "The list of integrations to report the results and logs to."}
     )
+    ddp_find_unused_parameters: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": (
+                "When using distributed training, the value of the flag `find_unused_parameters` passed to "
+                "`DataParallel`."
+            )
+        },
+    )
     wandb_api_key: Optional[str] = field(
         default=None,
         metadata={"help": "Weights & Biases (WandB) API key(s) for authentication with the WandB service."},
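For completeness, a hedged sketch of how the new argument is set from user code; only `ddp_find_unused_parameters` comes from this commit, the output directory is a placeholder, and the flag can equally be passed on the command line (e.g. `--ddp_find_unused_parameters False`) when a script parses `TrainingArguments` with `PdArgumentParser`:

```python
from paddlenlp.trainer import TrainingArguments

# None (the default) keeps the automatic behaviour: False when recompute is enabled
# on the wrapped PretrainedModel, True otherwise.
args = TrainingArguments(
    output_dir="./checkpoints",        # placeholder path
    ddp_find_unused_parameters=False,  # or True, or leave unset for the heuristic
)
print(args.ddp_find_unused_parameters)
```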

paddlenlp/transformers/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -303,3 +303,6 @@
 from .mamba.configuration import *
 from .mamba.modeling import *
 from .mamba.tokenizer import *
+from .jamba.modeling import *
+from .jamba.configuration import *
+from .jamba.tokenizer import *
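A quick, illustrative check that the wildcard imports above expose the new classes at the package level (assuming the jamba modules list them in their `__all__`):

```python
from paddlenlp.transformers import JambaConfig

config = JambaConfig()    # defaults mirror Jamba-v0.1 (see configuration.py below)
print(config.model_type)  # "jamba"
```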

paddlenlp/transformers/auto/modeling.py

Lines changed: 1 addition & 0 deletions

@@ -124,6 +124,7 @@
         ("Gemma", "gemma"),
         ("Yuan", "yuan"),
         ("Mamba", "mamba"),
+        ("Jamba", "jamba"),
     ]
 )
 

paddlenlp/transformers/auto/tokenizer.py

Lines changed: 1 addition & 0 deletions

@@ -100,6 +100,7 @@
         ("GemmaTokenizer", "gemma"),
         ("YuanTokenizer", "yuan"),
         ("MambaTokenizer", "mamba"),
+        ("JambaTokenizer", "jamba"),
     ]
 )
 
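With these registry entries, the `Auto*` factories can resolve the `jamba` model type by name. A hedged usage sketch: the checkpoint identifier is borrowed from the configuration docstring below and is an assumption about what your environment can actually resolve, and whether `AutoModelForCausalLM` applies depends on the modeling file, which is not shown in this view.

```python
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

name_or_path = "ai21labs/Jamba-v0.1"  # assumption: substitute a local path or another identifier if needed

tokenizer = AutoTokenizer.from_pretrained(name_or_path)  # dispatches to JambaTokenizer
config = AutoConfig.from_pretrained(name_or_path)        # dispatches to JambaConfig
model = AutoModelForCausalLM.from_pretrained(name_or_path)
```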

New file

Lines changed: 13 additions & 0 deletions

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
paddlenlp/transformers/jamba/configuration.py (new file)

Lines changed: 223 additions & 0 deletions

# coding=utf-8
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Jamba model configuration"""
import math

from ..configuration_utils import PretrainedConfig

__all__ = [
    "JambaConfig",
]


class JambaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
    Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Jamba-v0.1 model.

    [ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 65536):
            Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`JambaModel`].
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if
            the model has an output word embedding layer.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by mean-pooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
            integer value, only the last `num_logits_to_keep` logits will be calculated. Default is 1 because only
            the logits of the last prompt token are needed for generation. For long sequences, the logits for the
            entire sequence may use a lot of memory, so setting `num_logits_to_keep=1` reduces the memory footprint
            significantly.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        sliding_window (`int`, *optional*):
            Sliding window attention window size. If not specified, will default to `None`.
        max_position_embeddings (`int`, *optional*, defaults to 262144):
            This value doesn't have any real effect. The maximum sequence length that this model is intended to be
            used with. It can be used with longer sequences, but performance may degrade.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route per token; can also be interpreted as the `top-k` routing
            parameter.
        num_experts (`int`, *optional*, defaults to 16):
            Number of experts per Sparse MLP layer.
        expert_layer_period (`int`, *optional*, defaults to 2):
            Once in this many layers, we will have an expert layer.
        expert_layer_offset (`int`, *optional*, defaults to 1):
            The first layer index that contains an expert MLP layer.
        attn_layer_period (`int`, *optional*, defaults to 8):
            Once in this many layers, we will have a vanilla attention layer.
        attn_layer_offset (`int`, *optional*, defaults to 4):
            The first layer index that contains a vanilla attention layer.
        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
            `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
            `True` and kernels are not available.
        mamba_d_state (`int`, *optional*, defaults to 16):
            The dimension of the mamba state space latents.
        mamba_d_conv (`int`, *optional*, defaults to 4):
            The size of the mamba convolution kernel.
        mamba_expand (`int`, *optional*, defaults to 2):
            Expanding factor (relative to hidden_size) used to determine the mamba intermediate size.
        mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
            Rank of the mamba discretization projection matrix. `"auto"` means that it will default to
            `math.ceil(self.hidden_size / 16)`.
        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"])
            of the mamba mixer block.

    """

    model_type = "jamba"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=65536,
        tie_word_embeddings=False,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        num_logits_to_keep=1,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        sliding_window=None,
        max_position_embeddings=262144,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_experts=16,
        expert_layer_period=2,
        expert_layer_offset=1,
        attn_layer_period=8,
        attn_layer_offset=4,
        use_mamba_kernels=True,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_dt_rank="auto",
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        **kwargs,
    ):
        kwargs["return_dict"] = kwargs.pop("return_dict", True)
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        self.use_cache = use_cache
        self.num_logits_to_keep = num_logits_to_keep
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.expert_layer_period = expert_layer_period
        self.expert_layer_offset = expert_layer_offset
        self.attn_layer_period = attn_layer_period
        self.attn_layer_offset = attn_layer_offset

        self.use_mamba_kernels = use_mamba_kernels
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
        self.mamba_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias

    @property
    def layers_block_type(self):
        return [
            "attention" if i % self.attn_layer_period == self.attn_layer_offset else "mamba"
            for i in range(self.num_hidden_layers)
        ]

    @property
    def layers_num_experts(self):
        return [
            self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
            for i in range(self.num_hidden_layers)
        ]
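The two properties at the bottom encode Jamba's hybrid schedule: with the defaults, layer `i` is an attention layer when `i % attn_layer_period == attn_layer_offset` (every 8th layer, starting at index 4) and carries `num_experts` experts when `i % expert_layer_period == expert_layer_offset` (every other layer, starting at index 1). A small sketch that prints the schedule for a reduced layer count:

```python
from paddlenlp.transformers import JambaConfig

config = JambaConfig(num_hidden_layers=8)  # shrink the stack so the printout stays short

for i, (block, experts) in enumerate(zip(config.layers_block_type, config.layers_num_experts)):
    print(f"layer {i}: {block:9s} experts={experts}")
# -> mamba everywhere except layer 4 (attention); 16 experts on every odd-indexed layer.
```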
