Commit 89cb65d

the draft of add gpt-oss model
1 parent 8d77d18 commit 89cb65d

File tree: 3 files changed (+62 / -49 lines)

lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 from .base_weight import BaseWeightTpl
 from lightllm.utils.dist_utils import get_current_device_id

+# For special weight
 class DummyWeight(BaseWeightTpl):
     def __init__(self, weight_name, data_type):
         super().__init__()
@@ -15,10 +16,10 @@ def load_hf_weights(self, weights):

     def verify_load(self):
         load_ok = True
-        # Verify weight. The weight must be not None.
         load_ok = load_ok and self.weight is not None
         return load_ok

+
 class NormWeight(BaseWeightTpl):
     def __init__(self, weight_name, data_type, bias_name=None):
         super().__init__()
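
For context on this change, DummyWeight appears to be a thin holder for "special" weights that are loaded verbatim from the checkpoint rather than split or quantized, and verify_load only checks that the tensor actually arrived. A minimal, self-contained sketch of that pattern, assuming only torch and a Hugging-Face-style weights dict (the class and method bodies here are illustrative, not the BaseWeightTpl API):

import torch


class DummyWeightSketch:
    """Holds a single named tensor taken as-is from a HF-style state dict."""

    def __init__(self, weight_name, data_type):
        self.weight_name = weight_name
        self.data_type_ = data_type
        self.weight = None  # filled in by load_hf_weights

    def load_hf_weights(self, weights):
        # Pick up the tensor by name if this shard of the checkpoint contains it.
        if self.weight_name in weights:
            self.weight = weights[self.weight_name].to(self.data_type_)

    def verify_load(self):
        # Same check as in the diff above: the weight must not be None after loading.
        return self.weight is not None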

lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py

Lines changed: 36 additions & 28 deletions
@@ -14,34 +14,40 @@

 logger = init_logger(__name__)

+
 class GptOssTransformerLayerInfer(LlamaTransformerLayerInfer):
     def __init__(self, layer_num, network_config, mode=[]):
         super().__init__(layer_num, network_config, mode)
-        self.hidden_size = self.network_config_['hidden_size']
+        self.hidden_size = self.network_config_["hidden_size"]
         self.alpha = 1.702
         self.limit = 7.0
-        self.top_k = network_config['num_experts_per_tok']
-        self.sliding_window = network_config['sliding_window']
+        self.top_k = network_config["num_experts_per_tok"]
+        self.sliding_window = network_config["sliding_window"]
         self.head_dim_ = network_config["head_dim"]

     def _bind_attention(self):
         self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         self._context_attention_kernel = self._conext_sliding_attention_flashattention
         self._token_attention_kernel = self._token_sliding_attention_flashattention
-
+
     def _bind_norm(self):
         self._att_norm = self._att_norm
         self._ffn_norm = self._ffn_norm
         return

-    def _experts(self, hidden_states: torch.Tensor, router_indices, routing_weights, layer_weight: GptOssTransformerLayerWeight):
+    def _experts(
+        self, hidden_states: torch.Tensor, router_indices, routing_weights, layer_weight: GptOssTransformerLayerWeight
+    ):
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
         num_experts = routing_weights.shape[1]

         hidden_states = hidden_states.repeat(num_experts, 1)
         hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-        gate_up = torch.bmm(hidden_states, layer_weight.gate_up_proj_weight) + layer_weight.gate_up_proj_bias.weight[..., None, :]
+        gate_up = (
+            torch.bmm(hidden_states, layer_weight.gate_up_proj_weight)
+            + layer_weight.gate_up_proj_bias.weight[..., None, :]
+        )
         gate, up = gate_up[..., ::2], gate_up[..., 1::2]
         gate = gate.clamp(min=None, max=self.limit)
         up = up.clamp(min=-self.limit, max=self.limit)
@@ -52,21 +58,17 @@ def _experts(self, hidden_states: torch.Tensor, router_indices, routing_weights,
         next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
         next_states = next_states.sum(dim=0)
         return next_states
-
-    def _att_norm(
-        self, input, infer_state, layer_weight
-    ) -> torch.Tensor:
+
+    def _att_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
         out = self._gpt_oss_rmsnorm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
         return out
-
-    def _ffn_norm(
-        self, input, infer_state, layer_weight
-    ) -> torch.Tensor:
+
+    def _ffn_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
         out = self._gpt_oss_rmsnorm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
         return out
-
+
     def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
@@ -81,18 +83,24 @@ def _router(self, hidden_states, layer_weight: GptOssTransformerLayerWeight):
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
         return router_scores, router_indices
-
-    def _ffn(self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor:
+
+    def _ffn(
+        self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight
+    ) -> torch.Tensor:
         router_scores, router_indices = self._router(input, layer_weight)  # (num_experts, seq_len)
-        routed_out = self._experts(input, router_indices=router_indices, routing_weights=router_scores, layer_weight=layer_weight)
+        routed_out = self._experts(
+            input, router_indices=router_indices, routing_weights=router_scores, layer_weight=layer_weight
+        )
         return routed_out
-
-    def _conext_sliding_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None):
-        if self.network_config_['layer_types'][self.layer_num_] == "sliding_attention":
-            window_size = (self.sliding_window-1, self.sliding_window-1)
+
+    def _conext_sliding_attention_flashattention(
+        self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None
+    ):
+        if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
+            window_size = (self.sliding_window - 1, self.sliding_window - 1)
         else:
             window_size = (-1, -1)
-
+
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
             -1, 1, self.tp_k_head_num_, self.head_dim_
         )
@@ -114,7 +122,7 @@ def _conext_sliding_attention_flashattention
             max_seqlen_q=infer_state.q_max_seq_len,
             softmax_scale=sm_scale,
             causal=True,
-            window_size=(-1, -1),
+            window_size=window_size,
             softcap=0.0,
             k_descale=k_descale,
             v_descale=v_descale,
@@ -124,11 +132,11 @@ def _conext_sliding_attention_flashattention
         return o

     def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None):
-        if self.network_config_['layer_types'][self.layer_num_] == "sliding_attention":
-            window_size = (self.sliding_window-1, self.sliding_window-1)
+        if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
+            window_size = (self.sliding_window - 1, self.sliding_window - 1)
         else:
             window_size = (-1, -1)
-
+
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
             -1, 1, self.tp_k_head_num_, self.head_dim_
         )
@@ -157,4 +165,4 @@ def _token_sliding_attention_flashattention
             return_softmax_lse=False,
             sinks=layer_weight.attn_sinks.weight,
         )
-        return o
+        return o
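
The hunks above reformat _experts, _ffn, and the two flash-attention wrappers but elide the middle of the expert computation. For orientation, below is a hedged, self-contained sketch of the routed FFN these methods implement together: the router scatters a top-k softmax back into a dense (tokens, experts) score matrix, every expert runs densely on every token, gate/up are de-interleaved and clamped with limit = 7.0, and the expert outputs are score-weighted and summed. The activation (glu = gate * sigmoid(gate * alpha) with alpha = 1.702, then (up + 1) * glu) and the down-projection bias follow the public GPT-OSS reference and are not visible in this diff, so treat them as assumptions rather than a copy of the lightllm code:

import torch


def gpt_oss_moe_ffn_sketch(
    hidden_states,   # (num_tokens, hidden_size)
    router_weight,   # (num_experts, hidden_size)
    router_bias,     # (num_experts,)
    gate_up_proj,    # (num_experts, hidden_size, 2 * inter_size), gate/up interleaved
    gate_up_bias,    # (num_experts, 2 * inter_size)
    down_proj,       # (num_experts, inter_size, hidden_size)
    down_bias,       # (num_experts, hidden_size)
    top_k=4,
    alpha=1.702,
    limit=7.0,
):
    num_experts, num_tokens = router_weight.shape[0], hidden_states.shape[0]

    # Router: top-k softmax scattered back into a dense (tokens, experts) score matrix.
    logits = hidden_states @ router_weight.t() + router_bias
    top_val, top_idx = logits.topk(top_k, dim=-1)
    top_val = torch.softmax(top_val, dim=-1, dtype=top_val.dtype)
    scores = torch.zeros_like(logits).scatter_(1, top_idx, top_val)

    # Experts: run every expert on every token, then weight by the scores and sum.
    x = hidden_states.repeat(num_experts, 1).view(num_experts, num_tokens, -1)
    gate_up = torch.bmm(x, gate_up_proj) + gate_up_bias[:, None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)          # assumed GPT-OSS gating
    out = torch.bmm((up + 1) * glu, down_proj) + down_bias[:, None, :]
    out = out * scores.t()[..., None]                 # (experts, tokens, hidden)
    return out.sum(dim=0)                             # (tokens, hidden_size)

Running every expert on every token keeps the draft simple (the zeroed scores mask out unselected experts) at the cost of extra compute versus a gather/scatter MoE kernel. Separately, the attention wrappers now pass the computed window_size through to flash attention, so sliding_attention layers actually get a (sliding_window - 1, sliding_window - 1) window instead of the previously hard-coded (-1, -1).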

lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py

Lines changed: 24 additions & 20 deletions
@@ -29,6 +29,7 @@
     -6.0,
 ]

+
 class GptOssTransformerLayerWeight(LlamaTransformerLayerWeight):
     def __init__(
         self,
@@ -75,14 +76,14 @@ def _init_weight_names(self):
         self._o_bias_name = f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"

         # MOE Layers
-        # model.layers.0.mlp.experts.down_proj_bias [32, 2880]
-        # model.layers.0.mlp.experts.down_proj_blocks [32, 2880, 90, 16]
-        # model.layers.0.mlp.experts.down_proj_scales [32, 2880, 90]
-        # model.layers.0.mlp.experts.gate_up_proj_bias [32, 5760]
-        # model.layers.0.mlp.experts.gate_up_proj_blocks [32, 5760, 90, 16]
-        # model.layers.0.mlp.experts.gate_up_proj_scales [32, 5760, 90]
-        # model.layers.0.mlp.router.bias [32]
-        # model.layers.0.mlp.router.weight [32, 2880]
+        # model.layers.0.mlp.experts.down_proj_bias      [32, 2880]
+        # model.layers.0.mlp.experts.down_proj_blocks    [32, 2880, 90, 16]
+        # model.layers.0.mlp.experts.down_proj_scales    [32, 2880, 90]
+        # model.layers.0.mlp.experts.gate_up_proj_bias   [32, 5760]
+        # model.layers.0.mlp.experts.gate_up_proj_blocks [32, 5760, 90, 16]
+        # model.layers.0.mlp.experts.gate_up_proj_scales [32, 5760, 90]
+        # model.layers.0.mlp.router.bias                 [32]
+        # model.layers.0.mlp.router.weight               [32, 2880]

         self._router_bias_name = f"model.layers.{self.layer_num_}.mlp.router.bias"
         self._router_weight_name = f"model.layers.{self.layer_num_}.mlp.router.weight"
@@ -108,25 +109,28 @@ def _post_weight_process(self):
             blocks=self.down_proj_weight_blocks.weight,
             scales=self.down_proj_weight_scales.weight,
             dtype=torch.bfloat16,
-        )[:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :]
-        # (32, 1440, 2880)
+        )[
+            :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+        ]  # (32, 1440, 2880)

         self.gate_up_proj_weight = self._convert_moe_packed_tensors(
             blocks=self.gate_up_proj_weight_blocks.weight,
             scales=self.gate_up_proj_weight_scales.weight,
             dtype=torch.bfloat16,
-        ) # (32, 2880, 5760)
+        )  # (32, 2880, 5760)
         expert_num = self.gate_up_proj_weight.shape[0]
         self.gate_up_proj_weight = self.gate_up_proj_weight.reshape(expert_num, -1, 2, self.moe_intermediate_size)[
             :, :, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
-        ].reshape(expert_num, -1, 2*self.split_inter_size)
-        # (32, 2880, 2880)
-
-        self.gate_up_proj_bias.weight = self.gate_up_proj_bias.weight.reshape(expert_num, 2, self.moe_intermediate_size)[
-            :, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
-        ].reshape(expert_num, 2*self.split_inter_size)
-        # (32, 2880)
-
+        ].reshape(
+            expert_num, -1, 2 * self.split_inter_size
+        )  # (32, 2880, 2880)
+
+        self.gate_up_proj_bias.weight = self.gate_up_proj_bias.weight.reshape(
+            expert_num, 2, self.moe_intermediate_size
+        )[:, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)].reshape(
+            expert_num, 2 * self.split_inter_size
+        )  # (32, 2880)
+

     def _convert_moe_packed_tensors(
         self,
         blocks,
@@ -179,4 +183,4 @@ def _convert_moe_packed_tensors(

         out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
         del blocks, scales, lut
-        return out.transpose(1, 2).contiguous()
+        return out.transpose(1, 2).contiguous()
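
The shape comments above imply that the expert weights ship as MXFP4-style packed tensors: uint8 blocks holding two 4-bit codes per byte plus one scale per 32-value group (90 groups of 16 bytes covering 2880 values per row), and _convert_moe_packed_tensors expands them to bfloat16 before the tensor-parallel split. A hedged sketch of that conversion, consistent with the reshape/transpose tail in the last hunk; the full FP4 lookup table (only its trailing -6.0 entry is visible above) and the biased power-of-two interpretation of the scales are assumptions taken from the MXFP4 format, not from this diff:

import torch

# FP4 (E2M1) code points; only the trailing -6.0 of the module-level table appears above.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def convert_moe_packed_tensors_sketch(blocks, scales, dtype=torch.bfloat16):
    # blocks: (experts, rows, G, B) uint8; scales: (experts, rows, G) uint8.
    # Returns (experts, G * B * 2, rows), e.g. (32, 2880, 5760) for the gate_up blocks.
    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
    *prefix_shape, G, B = blocks.shape

    lo = (blocks & 0x0F).long()   # low nibble of each packed byte
    hi = (blocks >> 4).long()     # high nibble
    vals = torch.stack([lut[lo], lut[hi]], dim=-1).reshape(*prefix_shape, G, B * 2)

    # Assumption: scales are E8M0 exponents with a 127 bias, i.e. multiply by 2 ** (scale - 127).
    exp = (scales.to(torch.int32) - 127).to(dtype)
    vals = vals * (2.0 ** exp)[..., None]

    out = vals.reshape(*prefix_shape, G * B * 2)
    return out.transpose(1, 2).contiguous()

The tensor-parallel slicing in _post_weight_process then keeps split_inter_size columns of the intermediate dimension per rank, which is where the (32, 1440, 2880) and (32, 2880, 2880) shape comments come from.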
