Commit fb242d9

fix problem in review
1 parent 545b35e commit fb242d9

3 files changed: +221 -171 lines changed

Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
import os
import torch
import threading
from typing import Optional, Tuple, List, Dict, Any

from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id
from lightllm.common.quantization import Quantcfg
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

FP4_VALUES = [
    +0.0,
    +0.5,
    +1.0,
    +1.5,
    +2.0,
    +3.0,
    +4.0,
    +6.0,
    -0.0,
    -0.5,
    -1.0,
    -1.5,
    -2.0,
    -3.0,
    -4.0,
    -6.0,
]


class GPTOSSFusedMoeWeightTP(FusedMoeWeightTP):
    def __init__(
        self,
        gate_up_proj_name: str,  # differs from FusedMoeWeightTP
        down_proj_name: str,
        e_score_correction_bias_name: str,
        weight_prefix: str,
        n_routed_experts: int,
        num_fused_shared_experts: int,
        split_inter_size: int,
        data_type: torch.dtype,
        network_config: Dict[str, Any],
        layer_num: int,
        world_size: int = 1,  # differs from FusedMoeWeightTP
        quant_cfg: Quantcfg = None,
    ) -> None:
        super().__init__(
            gate_up_proj_name,
            down_proj_name,
            gate_up_proj_name,
            e_score_correction_bias_name,
            weight_prefix,
            n_routed_experts,
            num_fused_shared_experts,
            split_inter_size,
            data_type,
            network_config,
            layer_num,
            quant_cfg,
        )
        self.hidden_size = network_config["hidden_size"]

        self.alpha = 1.702
        self.limit = 7.0
        self.tp_world_size_ = world_size

        self.w1_bias = None
        self.w2_bias = None

        self._down_bias_name = f"{weight_prefix}.{down_proj_name}_bias"
        self._down_blocks_name = f"{weight_prefix}.{down_proj_name}_blocks"
        self._down_scales_name = f"{weight_prefix}.{down_proj_name}_scales"
        self._gate_up_bias_name = f"{weight_prefix}.{gate_up_proj_name}_bias"
        self._gate_up_blocks_name = f"{weight_prefix}.{gate_up_proj_name}_blocks"
        self._gate_up_scales_name = f"{weight_prefix}.{gate_up_proj_name}_scales"
        return

    def _fuse_weight_scale(self):
        assert False, "Not implemented for GPT-OSS."

    def _fuse(self):
        assert False, "Not implemented for GPT-OSS."

    def load_hf_weights(self, weights):
        if (
            weights.get(self._down_blocks_name, None) is not None
            and weights.get(self._down_scales_name, None) is not None
        ):
            w2 = self._convert_moe_packed_tensors(
                blocks=weights[self._down_blocks_name],
                scales=weights[self._down_scales_name],
                dtype=torch.bfloat16,
            )[:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :]
            self.w2 = (self._cuda(w2), None)

        if (
            weights.get(self._gate_up_blocks_name, None) is not None
            and weights.get(self._gate_up_scales_name, None) is not None
        ):
            w1 = self._convert_moe_packed_tensors(
                blocks=weights[self._gate_up_blocks_name],
                scales=weights[self._gate_up_scales_name],
                dtype=torch.bfloat16,
            )[:, :, self.split_inter_size * self.tp_rank_ * 2 : self.split_inter_size * (self.tp_rank_ + 1) * 2]
            self.w1 = (self._cuda(w1), None)

        if weights.get(self._gate_up_bias_name, None) is not None:
            w1_bias = weights[self._gate_up_bias_name][
                :, self.split_inter_size * self.tp_rank_ * 2 : self.split_inter_size * (self.tp_rank_ + 1) * 2
            ]
            self.w1_bias = self._cuda(w1_bias)

        if weights.get(self._down_bias_name, None) is not None:
            w2_bias = weights[self._down_bias_name]
            self.w2_bias = self._cuda(w2_bias)

    def experts(self, hidden_states: torch.Tensor, routing_weights, layer_num):
        w1, w1_scale = self.w1
        w2, w2_scale = self.w2
        assert w1_scale is None and w2_scale is None, "For now, we do not support quantized weight in GPT-OSS."

        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
        num_experts = routing_weights.shape[1]

        # run every token through every expert, then weight the outputs by the router scores
        hidden_states = hidden_states.repeat(num_experts, 1)
        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
        gate_up = torch.bmm(hidden_states, w1) + self.w1_bias[..., None, :]
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]  # gate/up columns are interleaved in the fused projection
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        next_states = torch.bmm(((up + 1) * glu), w2)
        next_states = next_states + self.w2_bias[..., None, :] / self.tp_world_size_
        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
        next_states = next_states.sum(dim=0)
        return next_states

    def _convert_moe_packed_tensors(
        self,
        blocks,
        scales,
        *,
        dtype: torch.dtype = torch.bfloat16,
        rows_per_chunk: int = 32768 * 1024,
    ) -> torch.Tensor:
        """
        Dequantize the packed MXFP4 weights and make them compatible with the GPT-OSS forward pass.
        """
        import math

        # Check if blocks and scales are on CPU, and move to GPU if so
        if not blocks.is_cuda and torch.cuda.is_available():
            blocks = blocks.cuda()
            scales = scales.cuda()

        scales = scales.to(torch.int32) - 127  # remove the exponent bias (127)

        assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}"

        lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

        *prefix_shape, G, B = blocks.shape
        rows_total = math.prod(prefix_shape) * G

        blocks = blocks.reshape(rows_total, B)
        scales = scales.reshape(rows_total, 1)

        out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

        for r0 in range(0, rows_total, rows_per_chunk):
            r1 = min(r0 + rows_per_chunk, rows_total)

            blk = blocks[r0:r1]
            exp = scales[r0:r1]

            # nibble indices -> int64
            idx_lo = (blk & 0x0F).to(torch.long)
            idx_hi = (blk >> 4).to(torch.long)

            sub = out[r0:r1]
            sub[:, 0::2] = lut[idx_lo]
            sub[:, 1::2] = lut[idx_hi]

            torch.ldexp(sub, exp, out=sub)
            del idx_lo, idx_hi, blk, exp, sub

        out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
        del blocks, scales, lut
        return out.transpose(1, 2).contiguous()
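
For readers unfamiliar with the packed format handled by `_convert_moe_packed_tensors` above: each byte stores two 4-bit codes that index into `FP4_VALUES`, and each block carries one scale whose exponent bias (127) is subtracted before `torch.ldexp` rescales the looked-up values. The standalone sketch below illustrates the same unpack-and-rescale idea on a toy block; the helper name `dequant_mxfp4_block` and the shapes are illustrative only and are not part of this commit.

```python
import torch

FP4_VALUES = [+0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequant_mxfp4_block(blocks: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Unpack two FP4 values per byte via a lookup table, then apply the per-block scale."""
    lut = torch.tensor(FP4_VALUES, dtype=torch.float32)
    exp = scales.to(torch.int32) - 127            # scales store the exponent plus a bias of 127
    idx_lo = (blocks & 0x0F).to(torch.long)       # low nibble -> even output positions
    idx_hi = (blocks >> 4).to(torch.long)         # high nibble -> odd output positions
    out = torch.empty(*blocks.shape[:-1], blocks.shape[-1] * 2, dtype=torch.float32)
    out[..., 0::2] = lut[idx_lo]
    out[..., 1::2] = lut[idx_hi]
    return torch.ldexp(out, exp.unsqueeze(-1))    # value * 2**exp

# toy example: one 32-element block = 16 packed bytes plus one scale byte
blocks = torch.randint(0, 256, (1, 16), dtype=torch.uint8)
scales = torch.tensor([128], dtype=torch.uint8)   # 128 - 127 = 1 -> multiply the block by 2
print(dequant_mxfp4_block(blocks, scales).shape)  # torch.Size([1, 32])
```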

lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py

Lines changed: 4 additions & 31 deletions
@@ -8,7 +8,6 @@
 from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight
 from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
-from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
 from lightllm.utils.sgl_utils import flash_attn_with_kvcache
 from lightllm.utils.log_utils import init_logger

@@ -35,30 +34,6 @@ def _bind_norm(self):
         self._ffn_norm = self._ffn_norm
         return

-    def _experts(
-        self, hidden_states: torch.Tensor, router_indices, routing_weights, layer_weight: GptOssTransformerLayerWeight
-    ):
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
-        num_experts = routing_weights.shape[1]
-
-        hidden_states = hidden_states.repeat(num_experts, 1)
-        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-        gate_up = (
-            torch.bmm(hidden_states, layer_weight.gate_up_proj_weight)
-            + layer_weight.gate_up_proj_bias.weight[..., None, :]
-        )
-        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-        gate = gate.clamp(min=None, max=self.limit)
-        up = up.clamp(min=-self.limit, max=self.limit)
-        glu = gate * torch.sigmoid(gate * self.alpha)
-        next_states = torch.bmm(((up + 1) * glu), layer_weight.down_proj_weight)
-        next_states = next_states + layer_weight.down_proj_bias.weight[..., None, :]
-        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
-        next_states = next_states.sum(dim=0)
-        return next_states
-
     def _att_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
         out = self._gpt_oss_rmsnorm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
@@ -78,19 +53,17 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6):

     def _router(self, hidden_states, layer_weight: GptOssTransformerLayerWeight):
         hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = layer_weight.moe_gate.mm(hidden_states)  # (seq_len, num_experts)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
+        router_logits = layer_weight.moe_gate.mm(hidden_states)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
         return router_scores, router_indices

     def _ffn(
         self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight
     ) -> torch.Tensor:
-        router_scores, router_indices = self._router(input, layer_weight)  # (num_experts, seq_len)
-        routed_out = self._experts(
-            input, router_indices=router_indices, routing_weights=router_scores, layer_weight=layer_weight
-        )
+        router_scores, _ = self._router(input, layer_weight)
+        routed_out = layer_weight.experts.experts(input, routing_weights=router_scores, layer_num=self.layer_num_)
         return routed_out

     def _conext_sliding_attention_flashattention(
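
To make the refactor easier to follow: `_ffn` no longer runs its own `_experts` loop and instead hands the router scores to `layer_weight.experts.experts(...)`, which applies the same clamped, alpha-scaled SiLU gating shown in the new weight class. Below is a minimal dense reference of that computation under made-up shapes; it is a sketch only, and it omits the tensor-parallel slicing and the `w2_bias / tp_world_size_` correction in the real code.

```python
import torch

def gpt_oss_moe_reference(x, w1, b1, w2, b2, routing_weights, alpha=1.702, limit=7.0):
    """Dense reference of the expert forward that _ffn now delegates to the weight object.

    x:               (num_tokens, hidden)
    w1, b1:          (num_experts, hidden, 2 * inter), (num_experts, 2 * inter)
    w2, b2:          (num_experts, inter, hidden), (num_experts, hidden)
    routing_weights: (num_tokens, num_experts), zero for non-selected experts
    """
    num_experts = routing_weights.shape[1]
    hidden_size = x.shape[-1]
    h = x.repeat(num_experts, 1).view(num_experts, -1, hidden_size)  # every token through every expert
    gate_up = torch.bmm(h, w1) + b1[:, None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]                 # interleaved gate/up columns
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)                         # clamped, alpha-scaled SiLU
    out = torch.bmm((up + 1) * glu, w2) + b2[:, None, :]
    return (out * routing_weights.t()[..., None]).sum(dim=0)        # combine experts by router score

# tiny smoke test: 4 experts, 3 tokens, hidden 8, intermediate 16
E, T, H, I = 4, 3, 8, 16
x = torch.randn(T, H)
w1, b1 = torch.randn(E, H, 2 * I), torch.randn(E, 2 * I)
w2, b2 = torch.randn(E, I, H), torch.randn(E, H)
scores = torch.softmax(torch.randn(T, E), dim=-1)
print(gpt_oss_moe_reference(x, w1, b1, w2, b2, scores).shape)  # torch.Size([3, 8])
```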
