
Commit 0cd6eb3

[model] Add GPT OSS model support: SFT/LoRA and inference (#2555)
Co-authored-by: wangyanbo05 <[email protected]>
Co-authored-by: YB <[email protected]>
1 parent 0ee3767 · commit 0cd6eb3

File tree: 10 files changed (+2017, -0 lines)

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+{
+    "model_name_or_path": "../gpt-oss-model-bf16",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/gptoss_paddle_sft_ckpts",
+    "overwrite_output_dir": false,
+    "per_device_train_batch_size": 1,
+    "gradient_accumulation_steps": 4,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 1,
+    "learning_rate": 3e-05,
+    "warmup_steps": 10,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 1024,
+    "max_length": 2048,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": false,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": false,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 4,
+    "pipeline_parallel_degree": 1,
+    "sharding": "stage2",
+    "zero_padding": false,
+    "unified_checkpoint": true,
+    "use_flash_attention": false,
+    "lora": true,
+    "lora_rank": 8,
+    "rslora": false,
+    "lora_plus_scale": 1.0,
+    "pissa": false,
+    "use_quick_lora": false,
+    "lora_use_mixer": false,
+    "use_mora": false
+}
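
The file above is an SFT configuration with LoRA enabled ("lora": true, "lora_rank": 8), bf16 training, and 4-way tensor parallelism. Below is a minimal sanity-check sketch (not part of the commit) that loads the JSON and works out what a few of the fields imply; the file name sft_lora_argument.json is a placeholder for wherever the config is saved.

import json

# Placeholder path; the commit does not show where this config file lives.
with open("sft_lora_argument.json") as f:
    cfg = json.load(f)

# src_length is the prompt budget; max_length must also cover the response tokens.
assert cfg["src_length"] <= cfg["max_length"]

# Samples consumed per device per optimizer step.
print(cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"])  # 1 * 4 = 4

# Minimum number of GPUs implied by the parallelism settings; data/sharding
# parallelism multiplies on top of this.
print(cfg["tensor_parallel_degree"] * cfg["pipeline_parallel_degree"])  # 4 * 1 = 4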

paddleformers/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -174,6 +174,8 @@
     "ernie4_5_moe.configuration": ["Ernie4_5_MoeConfig"],
     "ernie4_5_moe.modeling": ["Ernie4_5_MoeModel", "Ernie4_5_MoeForCausalLM", "Ernie4_5_MoeForCausalLMPipe"],
     "export": ["export_model"],
+    "gpt_oss.configuration": ["GptOssConfig"],
+    "gpt_oss.modeling": ["GptOssModel", "GptOssForCausalLM"],
     "llama.configuration": [
         "LLAMA_PRETRAINED_INIT_CONFIGURATION",
         "LlamaConfig",
@@ -400,6 +402,7 @@
     from .qwen2_moe import *
     from .qwen3 import *
     from .qwen3_moe import *
+    from .gpt_oss import *
 else:
     sys.modules[__name__] = _LazyModule(
         __name__,
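
These registrations make the new classes importable from the package root, both through the lazy-import table and the eager-import branch. A quick import check (a sketch, not part of the commit):

from paddleformers.transformers import GptOssConfig, GptOssForCausalLM

config = GptOssConfig()   # defaults come from the GptOssConfig definition shown further below
print(config.model_type)  # "gpt_oss"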

paddleformers/transformers/auto/configuration.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
         ("qwen2_moe", "Qwen2MoeConfig"),
         ("qwen3", "Qwen3Config"),
         ("qwen3_moe", "Qwen3MoeConfig"),
+        ("gpt_oss", "GptOssConfig"),
     ]
 )

paddleformers/transformers/auto/modeling.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@
         ("Qwen3", "qwen3"),
         ("Qwen2Moe", "qwen2_moe"),
         ("Qwen3Moe", "qwen3_moe"),
+        ("GptOss", "gpt_oss"),
     ]
 )
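
Together, the two mappings route model_type "gpt_oss" to GptOssConfig and the "GptOss" class prefix to the gpt_oss module, so the Auto factories can resolve a GPT OSS checkpoint directly. A hedged usage sketch, assuming the PaddleNLP-style AutoConfig/AutoModelForCausalLM API that paddleformers exposes; the checkpoint path mirrors model_name_or_path from the SFT config above:

from paddleformers.transformers import AutoConfig, AutoModelForCausalLM

# Replace with a real directory containing config.json and model weights.
checkpoint = "../gpt-oss-model-bf16"

config = AutoConfig.from_pretrained(checkpoint)
print(type(config).__name__)  # expected: GptOssConfig

model = AutoModelForCausalLM.from_pretrained(checkpoint)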

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration import *
+from .modeling import *
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# from ..configuration_utils import PretrainedConfig, layer_type_validation
+from ..configuration_utils import PretrainedConfig
+
+# from ...modeling_rope_utils import rope_config_validation
+
+
+class GptOssConfig(PretrainedConfig):
+    r"""
+    Configuration class for the GPT OSS model. Instantiating a configuration with the
+    defaults yields a GPT OSS style configuration.
+    """
+
+    model_type = "gpt_oss"
+
+    def __init__(
+        self,
+        num_hidden_layers: int = 24,
+        num_local_experts: int = 128,
+        vocab_size: int = 201088,
+        hidden_size: int = 2880,
+        intermediate_size: int = 2880,
+        head_dim: int = 64,
+        num_attention_heads: int = 64,
+        num_key_value_heads: int = 8,
+        sliding_window: int = 128,
+        rope_theta: float = 150000.0,
+        tie_word_embeddings=False,
+        hidden_act: str = "silu",
+        initializer_range: float = 0.02,
+        max_position_embeddings=131072,
+        rms_norm_eps: float = 1e-5,
+        rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False},
+        attention_dropout: float = 0.0,
+        num_experts_per_tok=4,
+        router_aux_loss_coef: float = 0.9,
+        output_router_logits=False,
+        use_cache=True,
+        layer_types=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_experts = num_local_experts
+        self.sliding_window = sliding_window
+        self.num_experts_per_tok = num_experts_per_tok
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_dropout = attention_dropout
+        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
+            ]
+        # layer_type_validation(self.layer_types)
+
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        # rope_config_validation(self)
+
+        self.attention_bias = True
+        self.max_position_embeddings = max_position_embeddings
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.output_router_logits = output_router_logits
+        self.use_cache = use_cache
+        self.use_bias = False
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["GptOssConfig"]
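
A small usage sketch (not part of the commit) instantiating the config with its defaults to show the fields derived in __init__, in particular the alternating sliding/full attention schedule and the expert count stored under num_experts:

from paddleformers.transformers import GptOssConfig

cfg = GptOssConfig()
print(cfg.num_hidden_layers)          # 24
print(cfg.num_experts)                # 128, stored from num_local_experts
print(cfg.num_key_value_heads)        # 8 (64 query heads grouped over 8 KV heads)
print(cfg.layer_types[:4])            # ['sliding_attention', 'full_attention', 'sliding_attention', 'full_attention']
print(cfg.rope_scaling["rope_type"])  # 'yarn'
print(cfg.head_dim)                   # 64; falls back to hidden_size // num_attention_heads when passed None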
