Commit f3c2b68 — "moe support" (parent: 9573de8)

5 files changed: +232 / −15 lines
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Dict

import paddle
from paddle import nn

from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import (
    UnquantizedFusedMoEMethod,
)
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
from fastdeploy.model_executor.ops.npu import npu_quant_weight


class NPUMoEMethod(UnquantizedFusedMoEMethod):
    """
    NPU fused MoE method (unquantized weights).
    """

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        # Each expert's FFN weight arrives as [k, n]; transpose to [n, k] and
        # stack along a new leading expert axis -> [num_experts, n, k].
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        for weights in [up_gate_proj_weights, down_proj_weights]:
            for idx, weight in enumerate(weights):
                weights[idx] = weight.transpose([1, 0])
        stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
        stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)

        layer.up_gate_proj_weight.set_value(stacked_up_gate_proj_weights)
        layer.down_proj_weight.set_value(stacked_down_proj_weights)

    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Compute the fused MoE on NPU (tensor-parallel path).
        """
        from fastdeploy.model_executor.ops.npu import fused_sparse_moe

        fused_moe_out = fused_sparse_moe(
            x,
            gate.weight.transpose([1, 0]),
            layer.up_gate_proj_weight,
            layer.down_proj_weight,
            None,  # ffn1_bias
            None,  # ffn1_scale
            None,  # ffn2_bias
            None,  # ffn2_scale
            self.moe_quant_type,
            layer.top_k,
            layer.tp_size,
        )
        if layer.tp_size > 1:
            from fastdeploy.distributed.communication import (
                tensor_model_parallel_all_reduce,
            )

            tensor_model_parallel_all_reduce(fused_moe_out)

        return fused_moe_out

    def apply_ep_prefill(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.
        """
        raise NotImplementedError

    def apply_ep_decode(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Apply the EP decode method.
        """
        raise NotImplementedError


class NPUWeightOnlyMoEMethod(QuantMethodBase):
    """
    NPU weight-only quantized fused MoE method.
    """

    def __init__(
        self,
        quant_config: WeightOnlyConfig,
    ) -> None:
        super().__init__()
        self.quant_config = quant_config
        self.moe_quant_type = self.quant_config.algo

    def create_weights(self, layer: nn.Layer, state_dict: Dict[str, paddle.Tensor]):
        """
        Quantize the per-expert FFN weights and register them on the layer.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        assert len(up_gate_proj_weights) == layer.num_local_experts
        assert len(down_proj_weights) == layer.num_local_experts
        assert up_gate_proj_weights[0].shape == [
            layer.hidden_size,
            layer.moe_intermediate_size * 2,
        ]
        assert down_proj_weights[0].shape == [
            layer.moe_intermediate_size,
            layer.hidden_size,
        ]

        added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
        added_scale_attrs = [
            "up_gate_proj_weight_scale",
            "down_proj_weight_scale",
        ]

        for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
            weight_name = added_weight_attrs[idx]
            scale_name = added_scale_attrs[idx]

            weight_list = []
            weight_scale_list = []
            for i in range(layer.num_local_experts):
                quant_weight, scale = npu_quant_weight(
                    weight_tensor[i], self.moe_quant_type, -1, -1
                )  # weight is [k, n]
                weight_list.append(quant_weight.transpose([1, 0]))  # transpose weight to [n, k]
                weight_scale_list.append(scale)
            quanted_weight = paddle.stack(weight_list, axis=0)
            setattr(
                layer,
                weight_name,
                layer.create_parameter(
                    shape=quanted_weight.shape,
                    dtype=quanted_weight.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
            getattr(layer, weight_name).set_value(quanted_weight)

            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
            setattr(
                layer,
                scale_name,
                layer.create_parameter(
                    shape=quanted_weight_scale.shape,
                    dtype=quanted_weight_scale.dtype,
                ),
            )
            getattr(layer, scale_name).set_value(quanted_weight_scale)

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        NPU compute fused MoE with weight-only quantized weights.
        """
        from fastdeploy.model_executor.ops.npu import fused_sparse_moe

        fused_moe_out = fused_sparse_moe(
            x,
            gate.weight.transpose([1, 0]),
            layer.up_gate_proj_weight,
            layer.down_proj_weight,
            None,  # ffn1_bias
            (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
            None,  # ffn2_bias
            (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
            self.moe_quant_type,
            layer.top_k,
            layer.tp_size,
        )
        if layer.tp_size > 1:
            from fastdeploy.distributed.communication import (
                tensor_model_parallel_all_reduce,
            )

            tensor_model_parallel_all_reduce(fused_moe_out)

        return fused_moe_out
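As a reading aid for `process_loaded_weights` above and the `fused_sparse_moe` wrapper further down: the loader transposes each expert's `[hidden_size, 2 * moe_intermediate_size]` weight to `[n, k]` and stacks the experts into one tensor, and the op wrapper later swaps the last two axes back before calling the kernel. A minimal shape walk-through with toy sizes (the sizes are illustrative, not taken from the commit):

import paddle

# Toy sizes, purely illustrative.
num_experts, hidden_size, moe_intermediate_size = 4, 8, 16

# Per-expert up_gate_proj weights as extracted from the state dict: [k, n].
up_gate = [paddle.randn([hidden_size, moe_intermediate_size * 2]) for _ in range(num_experts)]

# process_loaded_weights: transpose each expert to [n, k], stack -> [E, n, k].
stacked = paddle.stack([w.transpose([1, 0]) for w in up_gate], axis=0)
print(stacked.shape)  # [4, 32, 8]

# fused_sparse_moe: swap the last two axes back -> [E, k, n].
restored = paddle.transpose(stacked, [0, 2, 1])
print(restored.shape)  # [4, 8, 32]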

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,10 @@ def get_moe_method():
         )

         return MetaxTritonWeightOnlyMoEMethod(None)
+    elif current_platform.is_npu():
+        from .fused_moe_npu_backend import NPUMoEMethod
+
+        return NPUMoEMethod(None)
     raise NotImplementedError

fastdeploy/model_executor/layers/quantization/weight_only.py

Lines changed: 5 additions & 2 deletions
@@ -105,8 +105,11 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]:

             return GPUWeightOnlyLinearMethod(self)
         elif current_platform.is_npu():
-            from fastdeploy.model_executor.layers.backends import NPUWeightOnlyLinearMethod
-            return NPUWeightOnlyLinearMethod(self)
+            from fastdeploy.model_executor.layers.backends import (NPUWeightOnlyLinearMethod, NPUWeightOnlyMoEMethod)
+            if isinstance(layer, FusedMoE):
+                return NPUWeightOnlyMoEMethod(self)
+            else:
+                return NPUWeightOnlyLinearMethod(self)
         else:
             if isinstance(layer, FusedMoE):
                 if layer.use_method == "cutlass":

fastdeploy/model_executor/ops/npu/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 from .get_token_penalty_multi_scores import get_token_penalty_multi_scores_npu
 from .top_p_sampling import top_p_sampling_npu
 from .weight_quantize import npu_quant_weight
+from .sparse_moe import fused_sparse_moe


 PACKAGE = "fastdeploy.model_executor.ops.npu"
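This re-export is what lets the backend file above import the wrapper through the ops package rather than from the module directly. A quick sanity check of the import path, assuming a FastDeploy build in which the NPU ops package is importable:

# Assumes a FastDeploy build with the NPU ops package available.
from fastdeploy.model_executor.ops.npu import fused_sparse_moe

print(fused_sparse_moe.__module__)  # fastdeploy.model_executor.ops.npu.sparse_moe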

fastdeploy/model_executor/ops/npu/sparse_moe.py

Lines changed: 20 additions & 13 deletions
@@ -1,16 +1,19 @@
 import inspect

 import paddle
-import paddlenlp_ops
 from paddle.base import core
+import inspect
+from paddlenlp_ops import sparse_moe
+
+


 # npu interface refer to gpu interface
 def fused_sparse_moe(
     input,
     gate_weight,
-    ffn1_weight,
-    ffn2_weight,
+    ffn1_weight,
+    ffn2_weight,
     ffn1_bias,
     ffn1_scale,
     ffn2_bias,
@@ -22,27 +25,31 @@ def fused_sparse_moe(
     """
     call npu func to implement this function
     """
-    ffn1_weight = paddle.cast(ffn1_weight, paddle.bfloat16)
-    ffn2_weight = paddle.cast(ffn2_weight, paddle.bfloat16)

+    gate_weight = paddle.cast(gate_weight, paddle.bfloat16)
+
+    # ffn1_weight = paddle.cast(ffn1_weight, paddle.bfloat16)
+    ffn1_weight = paddle.transpose(ffn1_weight, [0, 2, 1])
+    # ffn2_weight = paddle.cast(ffn2_weight, paddle.bfloat16)
+    ffn2_weight = paddle.transpose(ffn2_weight, [0, 2, 1])

-    gate_weight = gate_weight.transpose([1, 0]).astype(input.dtype)

     temp = paddle.zeros([1]).astype(input.dtype)

+
     expert_array = paddle.arange(moe_topk * input.shape[0]).astype("int32")
     expert_group = paddle.ones([1]).astype("int32")
     one_hot = paddle.ones([1]).astype("int32")
     zero_hot = paddle.zeros([1]).astype("int32")

-    # define quant mapping: may modify
     if quant_method == "weight_int4_only":
         quanttype = 11
     elif quant_method == "weight_int8_only":
         quanttype = 6
     else:
         quanttype = 1
-    y = paddlenlp_ops.sparse_moe(
+
+    y = sparse_moe(
         input,
         gate_weight,
         temp,
@@ -51,24 +58,24 @@
         temp,
         temp,
         ffn1_weight,
-        ffn1_bias if ffn1_bias else temp,
+        ffn1_bias if ffn1_bias is not None else temp,
         temp,
         temp,
-        ffn1_scale,
+        ffn1_scale if ffn1_scale is not None else temp,
         temp,
         ffn2_weight,
-        ffn2_bias if ffn2_bias else temp,
+        ffn2_bias if ffn2_bias is not None else temp,
         temp,
         temp,
-        ffn2_scale,
+        ffn2_scale if ffn2_scale is not None else temp,
         temp,
         expert_array,
         expert_group,
         one_hot,
         zero_hot,
         moe_topk,
         input.dtype == paddle.bfloat16,
-        tp_size,
+        tp_size,
         quanttype,
     )
     return y
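The switch to `ffn1_bias if ffn1_bias is not None else temp` (and the same for the scales) is more than cosmetic: truth-testing a Paddle tensor generally only works for single-element tensors, and a legitimately all-zero value would be treated as missing. A small sketch of the behaviour the `is not None` form gives, using toy tensors that are not from the commit:

import paddle

def pick(bias, fallback):
    # Presence check only; never evaluates the tensor's truth value.
    return bias if bias is not None else fallback

temp = paddle.zeros([1])
bias = paddle.zeros([32])          # multi-element, all zeros

print(pick(bias, temp).shape)      # [32] -> the real bias is kept
print(pick(None, temp).shape)      # [1]  -> fallback used only when the bias is absent
# By contrast, `bias if bias else temp` would evaluate bool(bias), which is
# ambiguous for a 32-element tensor and would silently drop a zero-valued bias.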
