
Commit 8d77d18

Draft: add the gpt-oss model
1 parent 5ea50a9 commit 8d77d18

File tree

8 files changed: +404 -0 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,22 @@
 from .base_weight import BaseWeightTpl
 from lightllm.utils.dist_utils import get_current_device_id

+class DummyWeight(BaseWeightTpl):
+    def __init__(self, weight_name, data_type):
+        super().__init__()
+        self.weight_name = weight_name
+        self.data_type_ = data_type
+        self.weight = None
+
+    def load_hf_weights(self, weights):
+        if self.weight_name in weights:
+            self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
+
+    def verify_load(self):
+        load_ok = True
+        # Verify weight. The weight must not be None.
+        load_ok = load_ok and self.weight is not None
+        return load_ok

 class NormWeight(BaseWeightTpl):
     def __init__(self, weight_name, data_type, bias_name=None):
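For reviewers, a minimal sketch of how the new DummyWeight could be exercised on its own. The tensor name and shape below are hypothetical, and it assumes a lightllm worker context with a visible CUDA device, since load_hf_weights moves the tensor onto the GPU returned by get_current_device_id:

import torch
from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import DummyWeight

# Hypothetical checkpoint entry; in lightllm the dict comes from the safetensors loader.
weights = {"model.layers.0.self_attn.sinks": torch.zeros(64, dtype=torch.float32)}

sinks = DummyWeight("model.layers.0.self_attn.sinks", torch.bfloat16)
sinks.load_hf_weights(weights)  # casts to bfloat16 and moves to the current CUDA device
assert sinks.verify_load()      # True once the named tensor has been loaded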

lightllm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -35,4 +35,5 @@
     Tarsier2Qwen2VLTpPartModel,
     Tarsier2LlamaTpPartModel,
 )
+from lightllm.models.gpt_oss.model import GptOssTpPartModel
 from .registry import get_model

lightllm/models/gpt_oss/__init__.py

Whitespace-only changes.

lightllm/models/gpt_oss/layer_infer/__init__.py

Whitespace-only changes.

lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from functools import partial
from typing import Optional

from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

class GptOssTransformerLayerInfer(LlamaTransformerLayerInfer):
    def __init__(self, layer_num, network_config, mode=[]):
        super().__init__(layer_num, network_config, mode)
        self.hidden_size = self.network_config_["hidden_size"]
        self.alpha = 1.702
        self.limit = 7.0
        self.top_k = network_config["num_experts_per_tok"]
        self.sliding_window = network_config["sliding_window"]
        self.head_dim_ = network_config["head_dim"]

    def _bind_attention(self):
        self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
        self._context_attention_kernel = self._context_sliding_attention_flashattention
        self._token_attention_kernel = self._token_sliding_attention_flashattention

    def _bind_norm(self):
        self._att_norm = self._att_norm
        self._ffn_norm = self._ffn_norm
        return

    def _experts(self, hidden_states: torch.Tensor, router_indices, routing_weights, layer_weight: GptOssTransformerLayerWeight):
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
        num_experts = routing_weights.shape[1]

        hidden_states = hidden_states.repeat(num_experts, 1)
        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
        gate_up = torch.bmm(hidden_states, layer_weight.gate_up_proj_weight) + layer_weight.gate_up_proj_bias.weight[..., None, :]
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        next_states = torch.bmm(((up + 1) * glu), layer_weight.down_proj_weight)
        next_states = next_states + layer_weight.down_proj_bias.weight[..., None, :]
        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
        next_states = next_states.sum(dim=0)
        return next_states

    def _att_norm(
        self, input, infer_state, layer_weight
    ) -> torch.Tensor:
        out = self.alloc_tensor(input.shape, input.dtype)
        out = self._gpt_oss_rmsnorm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
        return out

    def _ffn_norm(
        self, input, infer_state, layer_weight
    ) -> torch.Tensor:
        out = self.alloc_tensor(input.shape, input.dtype)
        out = self._gpt_oss_rmsnorm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
        return out

    def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return (weight * hidden_states).to(input_dtype)  # main diff with Llama

    def _router(self, hidden_states, layer_weight: GptOssTransformerLayerWeight):
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        router_logits = layer_weight.moe_gate.mm(hidden_states)  # (seq_len, num_experts)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

    def _ffn(self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor:
        router_scores, router_indices = self._router(input, layer_weight)  # (seq_len, num_experts)
        routed_out = self._experts(input, router_indices=router_indices, routing_weights=router_scores, layer_weight=layer_weight)
        return routed_out

    def _context_sliding_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None):
        if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
            window_size = (self.sliding_window - 1, self.sliding_window - 1)
        else:
            window_size = (-1, -1)

        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
            -1, 1, self.tp_k_head_num_, self.head_dim_
        )
        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
        ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
        q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
        k_descale, v_descale = None, None  # disable quantization
        Lq = q.shape[-1]
        sm_scale = 1.0 / (Lq ** 0.5)
        o = flash_attn_with_kvcache(
            q=q,
            k_cache=cache_k,
            v_cache=cache_v,
            page_table=infer_state.page_table,
            cache_seqlens=infer_state.b_seq_len,
            cu_seqlens_q=infer_state.cu_seqlens_q,
            cu_seqlens_k_new=infer_state.cu_seqlens_k,
            max_seqlen_q=infer_state.q_max_seq_len,
            softmax_scale=sm_scale,
            causal=True,
            window_size=window_size,
            softcap=0.0,
            k_descale=k_descale,
            v_descale=v_descale,
            return_softmax_lse=False,
            sinks=layer_weight.attn_sinks.weight,
        )
        return o

    def _token_sliding_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None):
        if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention":
            window_size = (self.sliding_window - 1, self.sliding_window - 1)
        else:
            window_size = (-1, -1)

        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
            -1, 1, self.tp_k_head_num_, self.head_dim_
        )
        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
        ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
        q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
        k_descale, v_descale = None, None  # disable quantization
        Lq = q.shape[-1]
        sm_scale = 1.0 / (Lq ** 0.5)
        o = flash_attn_with_kvcache(
            q=q,
            k_cache=cache_k,
            v_cache=cache_v,
            page_table=infer_state.page_table,
            cache_seqlens=infer_state.b_seq_len,
            cu_seqlens_q=infer_state.cu_seqlens_q,
            cu_seqlens_k_new=infer_state.cu_seqlens_k,
            max_seqlen_q=1,
            softmax_scale=sm_scale,
            causal=True,
            window_size=window_size,
            softcap=0.0,
            k_descale=k_descale,
            v_descale=v_descale,
            return_softmax_lse=False,
            sinks=layer_weight.attn_sinks.weight,
        )
        return o
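The _experts method above computes the GPT-OSS MoE densely: every expert processes every token, and the per-expert outputs are then mixed by the router scores. The gated activation itself (one-sided clamp on gate, two-sided clamp on up, sigmoid gating with alpha = 1.702, and the +1 bias on the up branch) can be read in isolation. A small standalone sketch, using the constructor defaults from this file:

import torch

def gpt_oss_glu(gate: torch.Tensor, up: torch.Tensor, alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    gate = gate.clamp(max=limit)              # gate is clamped from above only
    up = up.clamp(min=-limit, max=limit)      # up is clamped on both sides
    glu = gate * torch.sigmoid(gate * alpha)  # sigmoid-weighted linear unit
    return (up + 1) * glu

# gate/up are interleaved in the last dim of gate_up_proj's output, as in _experts
gate_up = torch.randn(4, 8)
out = gpt_oss_glu(gate_up[..., ::2], gate_up[..., 1::2])  # shape (4, 4)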

lightllm/models/gpt_oss/layer_weights/__init__.py

Whitespace-only changes.

lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
import os
import torch
import numpy as np

from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ROWMMWeight
from lightllm.common.basemodel.layer_weights.meta_weights.norm_weight import DummyWeight
from lightllm.models.bloom import model
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

FP4_VALUES = [
    +0.0,
    +0.5,
    +1.0,
    +1.5,
    +2.0,
    +3.0,
    +4.0,
    +6.0,
    -0.0,
    -0.5,
    -1.0,
    -1.5,
    -2.0,
    -3.0,
    -4.0,
    -6.0,
]

class GptOssTransformerLayerWeight(LlamaTransformerLayerWeight):
    def __init__(
        self,
        layer_num,
        data_type,
        network_config,
        mode=[],
        quant_cfg=None,
    ):
        super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
        return

    def _init_moe(self):
        moe_mode = os.getenv("MOE_MODE", "TP")
        assert moe_mode in ["TP"], "For now, GPT-OSS models only support the MOE TP mode."
        self.moe_gate = ROWMMWeight(
            weight_name=self._router_weight_name,
            data_type=self.data_type_,
            layer_num=self.layer_num_,
            bias_name=self._router_bias_name,
            name="moe_gate",
            tp_rank=0,
            tp_world_size=1,
        )
        self.down_proj_bias = DummyWeight(self._down_bias_name, torch.bfloat16)
        self.down_proj_weight_blocks = DummyWeight(self._down_blocks_name, torch.uint8)
        self.down_proj_weight_scales = DummyWeight(self._down_scales_name, torch.uint8)

        self.gate_up_proj_bias = DummyWeight(self._gate_up_bias_name, torch.bfloat16)
        self.gate_up_proj_weight_blocks = DummyWeight(self._gate_up_blocks_name, torch.uint8)
        self.gate_up_proj_weight_scales = DummyWeight(self._gate_up_scales_name, torch.uint8)
        self.attn_sinks = DummyWeight(self._attn_sink_name, torch.bfloat16)

    def _init_weight_names(self):
        super()._init_weight_names()

        # Sinks
        self._attn_sink_name = f"model.layers.{self.layer_num_}.self_attn.sinks"

        # Bias
        self._q_bias_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"
        self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"
        self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"
        self._o_bias_name = f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"

        # MOE layers (checkpoint tensor shapes):
        # model.layers.0.mlp.experts.down_proj_bias       [32, 2880]
        # model.layers.0.mlp.experts.down_proj_blocks     [32, 2880, 90, 16]
        # model.layers.0.mlp.experts.down_proj_scales     [32, 2880, 90]
        # model.layers.0.mlp.experts.gate_up_proj_bias    [32, 5760]
        # model.layers.0.mlp.experts.gate_up_proj_blocks  [32, 5760, 90, 16]
        # model.layers.0.mlp.experts.gate_up_proj_scales  [32, 5760, 90]
        # model.layers.0.mlp.router.bias                  [32]
        # model.layers.0.mlp.router.weight                [32, 2880]

        self._router_bias_name = f"model.layers.{self.layer_num_}.mlp.router.bias"
        self._router_weight_name = f"model.layers.{self.layer_num_}.mlp.router.weight"

        self._down_bias_name = f"model.layers.{self.layer_num_}.mlp.experts.down_proj_bias"
        self._down_blocks_name = f"model.layers.{self.layer_num_}.mlp.experts.down_proj_blocks"
        self._down_scales_name = f"model.layers.{self.layer_num_}.mlp.experts.down_proj_scales"
        self._down_weight_name = None

        self._gate_up_bias_name = f"model.layers.{self.layer_num_}.mlp.experts.gate_up_proj_bias"
        self._gate_up_blocks_name = f"model.layers.{self.layer_num_}.mlp.experts.gate_up_proj_blocks"
        self._gate_up_scales_name = f"model.layers.{self.layer_num_}.mlp.experts.gate_up_proj_scales"
        self._gate_up_weight_name = None

    def _init_ffn(self):
        self._init_moe()

    def _post_weight_process(self):
        self.moe_intermediate_size = self.network_config_["intermediate_size"]
        self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_

        self.down_proj_weight = self._convert_moe_packed_tensors(
            blocks=self.down_proj_weight_blocks.weight,
            scales=self.down_proj_weight_scales.weight,
            dtype=torch.bfloat16,
        )[:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :]
        # e.g. (32, 1440, 2880) with tp_world_size_ == 2

        self.gate_up_proj_weight = self._convert_moe_packed_tensors(
            blocks=self.gate_up_proj_weight_blocks.weight,
            scales=self.gate_up_proj_weight_scales.weight,
            dtype=torch.bfloat16,
        )  # (32, 2880, 5760)
        expert_num = self.gate_up_proj_weight.shape[0]
        self.gate_up_proj_weight = self.gate_up_proj_weight.reshape(expert_num, -1, 2, self.moe_intermediate_size)[
            :, :, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
        ].reshape(expert_num, -1, 2 * self.split_inter_size)
        # e.g. (32, 2880, 2880) with tp_world_size_ == 2

        self.gate_up_proj_bias.weight = self.gate_up_proj_bias.weight.reshape(expert_num, 2, self.moe_intermediate_size)[
            :, :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
        ].reshape(expert_num, 2 * self.split_inter_size)
        # e.g. (32, 2880) with tp_world_size_ == 2

    def _convert_moe_packed_tensors(
        self,
        blocks,
        scales,
        *,
        dtype: torch.dtype = torch.bfloat16,
        rows_per_chunk: int = 32768 * 1024,  # chunk size chosen to bound peak memory during dequantization
    ) -> torch.Tensor:
        """
        Dequantize the packed MXFP4 weights and make them compatible with the forward
        pass of GPT-OSS.
        """
        import math

        # Check if blocks and scales are on CPU, and move to GPU if so
        if not blocks.is_cuda and torch.cuda.is_available():
            blocks = blocks.cuda()
            scales = scales.cuda()

        scales = scales.to(torch.int32) - 127  # remove the E8M0 exponent bias (127)

        assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}"

        lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

        *prefix_shape, G, B = blocks.shape
        rows_total = math.prod(prefix_shape) * G

        blocks = blocks.reshape(rows_total, B)
        scales = scales.reshape(rows_total, 1)

        out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

        for r0 in range(0, rows_total, rows_per_chunk):
            r1 = min(r0 + rows_per_chunk, rows_total)

            blk = blocks[r0:r1]
            exp = scales[r0:r1]

            # nibble indices -> int64
            idx_lo = (blk & 0x0F).to(torch.long)
            idx_hi = (blk >> 4).to(torch.long)

            sub = out[r0:r1]
            sub[:, 0::2] = lut[idx_lo]
            sub[:, 1::2] = lut[idx_hi]

            torch.ldexp(sub, exp, out=sub)
            del idx_lo, idx_hi, blk, exp, sub

        out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
        del blocks, scales, lut
        return out.transpose(1, 2).contiguous()
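_convert_moe_packed_tensors follows the MXFP4 layout visible in the checkpoint shapes above: each uint8 in the blocks tensor packs two 4-bit codes that index FP4_VALUES, and each group of 32 values shares one E8M0 exponent stored with a bias of 127. A toy, self-contained illustration of that unpacking for a single group (the shapes here are illustrative only; the real tensors are [num_experts, dim, 90, 16]):

import torch

FP4_VALUES = [+0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequant_group(blocks: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # blocks: packed uint8 bytes of one group, scale: its uint8 E8M0 exponent
    lut = torch.tensor(FP4_VALUES, dtype=torch.float32)
    lo = lut[(blocks & 0x0F).long()]                       # low nibble -> even positions
    hi = lut[(blocks >> 4).long()]                         # high nibble -> odd positions
    out = torch.stack([lo, hi], dim=-1).reshape(-1)
    return torch.ldexp(out, scale.to(torch.int32) - 127)   # apply 2**(exponent - 127)

group = torch.tensor([0x21, 0x74], dtype=torch.uint8)      # nibbles decode to 0.5, 1.0, 2.0, 6.0
print(dequant_group(group, torch.tensor(128, dtype=torch.uint8)))  # -> tensor([ 1.,  2.,  4., 12.])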
