
Commit ece3d13

Commit message: fix
1 parent 711b730 commit ece3d13

File tree

4 files changed: +26 −69 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py

Lines changed: 3 additions & 39 deletions
@@ -116,26 +116,13 @@ def load_hf_weights(self, weights):
             w2_bias = weights[self._down_bias_name]
             self.w2_bias = self._cuda(w2_bias)

-    # Keep torch version code for reference
-    def _torch_router(self, router_logits, top_k, layer_num):
+    def router(self, router_logits, top_k):
         router_top_value, router_indices = torch.topk(router_logits, top_k, dim=-1)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        return router_top_value, router_indices

     def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
-        from lightllm.common.fused_moe.topk_select import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=input_tensor,
-            router_logits=router_logits,
-            correction_bias=self.e_score_correction_bias,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-        )
+        topk_weights, topk_ids = self.router(router_logits, top_k)

         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
@@ -161,29 +148,6 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
         )
         return output_tensor

-    def _torch_experts(self, hidden_states: torch.Tensor, routing_weights, layer_num):
-        w1, w1_scale = self.w1
-        w2, w2_scale = self.w2
-        assert w1_scale is None and w2_scale is None, "For now, we do not support quantized weight in GPT-OSS."
-
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
-        num_experts = routing_weights.shape[1]
-
-        hidden_states = hidden_states.repeat(num_experts, 1)
-        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-        gate_up = torch.bmm(hidden_states, w1.transpose(1, 2)) + self.w1_bias[..., None, :]
-        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-        gate = gate.clamp(min=None, max=self.limit)
-        up = up.clamp(min=-self.limit, max=self.limit)
-        glu = gate * torch.sigmoid(gate * self.alpha)
-        next_states = torch.bmm(((up + 1) * glu), w2.transpose(1, 2))
-        next_states = next_states + self.w2_bias[..., None, :] / self.tp_world_size_
-        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
-        next_states = next_states.sum(dim=0)
-        return next_states
-
     def _convert_moe_packed_tensors(
         self,
         blocks,
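The new router() selects the top-k experts per token and softmax-normalizes only the selected logits; the dense scatter back into a full [num_tokens, num_experts] score matrix is no longer needed because the fused MoE path consumes (topk_weights, topk_ids) directly. A minimal standalone sketch of the equivalent computation (tensor names and sizes here are illustrative, not part of the repo):

import torch

def topk_softmax_router(router_logits: torch.Tensor, top_k: int):
    # router_logits: [num_tokens, num_experts]
    top_values, top_ids = torch.topk(router_logits, top_k, dim=-1)
    # softmax over the selected experts only, as in the new router()
    top_weights = torch.nn.functional.softmax(top_values, dim=-1, dtype=top_values.dtype)
    return top_weights, top_ids  # each [num_tokens, top_k]

logits = torch.randn(4, 32)  # e.g. 4 tokens, 32 experts (example sizes)
weights, ids = topk_softmax_router(logits, top_k=4)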

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 7 additions & 6 deletions
@@ -394,6 +394,9 @@ def grouped_matmul_kernel(
     weight_stride_0,
     weight_stride_1,
     weight_stride_2,
+    bias_ptr,  # [expert_num, n]
+    bias_stride_0,
+    bias_stride_1,
     expert_to_weights_ptr,  # [expert_num, token_num * topk]
     expert_to_weights_stride0,
     expert_to_weights_stride1,
@@ -418,9 +421,6 @@ def grouped_matmul_kernel(
     MUL_ROUTED_WEIGHT: tl.constexpr = False,
     NEED_K_MASK: tl.constexpr = True,
     NEED_TRANS: tl.constexpr = False,
-    # Bias support
-    bias_ptr=None,  # [expert_num, n]
-    bias_stride_0=0,
     ADD_BIAS: tl.constexpr = False,
 ):
     pid = tl.program_id(0)
@@ -535,7 +535,7 @@ def grouped_matmul_kernel(
     if ADD_BIAS:
         offs_bn_bias = offs_bn  # [BLOCK_SIZE_N]
         bias_ptrs = bias_ptr + expert_id * bias_stride_0 + offs_bn_bias
-        bias_vals = tl.load(bias_ptrs, mask=offs_bn_bias < n, other=0.0)  # [BLOCK_SIZE_N]
+        bias_vals = tl.load(bias_ptrs)  # [BLOCK_SIZE_N]
         accumulator += bias_vals[None, :]  # broadcast across M dimension

     if MUL_ROUTED_WEIGHT:
@@ -728,6 +728,9 @@ def grouped_matmul(
         expert_weights.stride(0),
         expert_weights.stride(1),
         expert_weights.stride(2),
+        bias,
+        bias.stride(0) if bias is not None else 0,
+        bias.stride(1) if bias is not None and bias.ndim >= 2 else 0,
         expert_to_weights,
         expert_to_weights.stride(0),
         expert_to_weights.stride(1),
@@ -753,8 +756,6 @@ def grouped_matmul(
         num_warps=num_warps,
         num_stages=num_stages,
         ADD_BIAS=bias is not None,
-        bias_ptr=bias,
-        bias_stride_0=bias.stride(0) if bias is not None else 0,
     )
     return (mblocks_to_expert_id, mblocks_to_m_index, BLOCK_SIZE_M)
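Here the per-expert bias enters the Triton kernel through regular positional pointer/stride arguments next to the weight strides, while ADD_BIAS stays a constexpr so the bias load is only compiled in when a bias tensor is actually passed. Semantically, the ADD_BIAS branch adds one bias row per expert to that expert's output tile. A plain PyTorch reference of that behavior (a sketch only; the function name and shapes are illustrative, not the kernel's signature):

import torch

def grouped_matmul_with_bias_ref(x, expert_ids, weights, bias=None):
    # x: [num_tokens, k], expert_ids: [num_tokens] (expert chosen per token)
    # weights: [num_experts, n, k], bias: [num_experts, n] or None
    out = torch.zeros(x.shape[0], weights.shape[1], dtype=x.dtype, device=x.device)
    for e in range(weights.shape[0]):
        sel = expert_ids == e
        if sel.any():
            y = x[sel] @ weights[e].t()
            if bias is not None:  # mirrors the ADD_BIAS branch
                y = y + bias[e]
            out[sel] = y
    return out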

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 16 additions & 16 deletions
@@ -16,13 +16,14 @@ def _silu_and_mul_kernel_fast(
     stride_output_n,
     size_m,
     size_n,
+    limit: tl.constexpr,
+    alpha: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     NUM_STAGES: tl.constexpr,
     NEED_MASK: tl.constexpr,
     layout: tl.constexpr = "blocked",  # "blocked" or "interleaved"
-    limit=None,
-    alpha=None,
+    USE_LIMIT_AND_ALPHA: tl.constexpr = False,
 ):
     stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
     stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)
@@ -63,27 +64,23 @@ def _silu_and_mul_kernel_fast(
             other=other,
         ).to(tl.float32)

-        if limit is None and alpha is None:
-            gate = gate / (1 + tl.exp(-gate))
+        if USE_LIMIT_AND_ALPHA:
+            gate = tl.minimum(gate, limit)
+            up = tl.minimum(tl.maximum(up, -limit), limit)
+            gate = 1 / (1 + tl.exp(-gate * alpha)) * gate
             gate = gate.to(input_ptr.dtype.element_ty)
-
             tl.store(
                 output_ptr + out_offsets,
-                up * gate,
+                (up + 1) * gate,
                 mask=mask,
             )
         else:
-            # clamp up and gate
-            if limit is not None:
-                gate = tl.minimum(gate, limit)
-                up = tl.minimum(tl.maximum(up, -limit), limit)
-            if alpha is None:
-                alpha = 1.0
-            gate = 1 / (1 + tl.exp(-gate * alpha)) * gate
+            gate = gate / (1 + tl.exp(-gate))
             gate = gate.to(input_ptr.dtype.element_ty)
+
             tl.store(
                 output_ptr + out_offsets,
-                (up + 1) * gate,
+                up * gate,
                 mask=mask,
             )

@@ -114,6 +111,7 @@ def silu_and_mul_fwd(
 ):
     assert input.is_contiguous()
     assert output.is_contiguous()
+    assert (limit is None and alpha is None) or (limit is not None and alpha is not None)

     stride_input_m = input.stride(0)
     stride_input_n = input.stride(1)
@@ -132,6 +130,7 @@ def silu_and_mul_fwd(
     # limit the grid size to avoid the invalid argument error of triton
     while triton.cdiv(size_m, BLOCK_M) > 8192:
         BLOCK_M *= 2
+    USE_LIMIT_AND_ALPHA = limit is not None and alpha is not None

     grid = (
         triton.cdiv(size_n, BLOCK_N),
@@ -147,13 +146,14 @@ def silu_and_mul_fwd(
         stride_output_n=stride_output_n,
         size_m=size_m,
         size_n=size_n,
+        limit=limit,
+        alpha=alpha,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         NUM_STAGES=NUM_STAGES,
         NEED_MASK=NEED_MASK,
         num_warps=num_warps,
         layout=layout,
-        limit=limit,
-        alpha=alpha,
+        USE_LIMIT_AND_ALPHA=USE_LIMIT_AND_ALPHA,
     )
     return
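With limit and alpha now tl.constexpr parameters and USE_LIMIT_AND_ALPHA computed on the host, the kernel specializes at compile time into two branches: the GPT-OSS path clamps gate and up, applies an alpha-scaled sigmoid gate, and multiplies by (up + 1); the default path is the usual SiLU-and-mul. A plain PyTorch reference of both branches (a sketch reconstructed from the kernel above; names are illustrative):

import torch

def silu_and_mul_ref(gate: torch.Tensor, up: torch.Tensor, limit=None, alpha=None):
    # Either both limit and alpha are given (GPT-OSS path) or neither,
    # matching the assert added to silu_and_mul_fwd.
    if limit is not None and alpha is not None:
        gate = gate.clamp(max=limit)
        up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(gate * alpha)
        return (up + 1) * glu
    return up * torch.nn.functional.silu(gate)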

lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 8 deletions
@@ -51,14 +51,6 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6):
         hidden_states = hidden_states * torch.rsqrt(variance + eps)
         return (weight * hidden_states).to(input_dtype)  # main diff with Llama

-    def _torch_router(self, hidden_states, layer_weight: GptOssTransformerLayerWeight):
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        router_logits = layer_weight.moe_gate.mm(hidden_states)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
     def _ffn(
         self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight
     ) -> torch.Tensor:
