
Commit ce1f353

Move create_parameters to __init__ in FuseMOE for CultassBackend and TritonBackend (#3148)
* w4a8 bug
* fix w4a8 bug
* remove code
* modify the triton backend
* fix ep
* fix the bug with tensor_wise_fp8 in triton backend
* fix the RL
* fix bug by merge
* fix the bug in w4a8
* fix the tensor_wise_fp8 bug
* fix RL
1 parent d0e9a70 commit ce1f353
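In effect, the commit switches FuseMOE weight handling from "create and fill at load time" (the old create_and_set_parameter helper) to a two-phase flow: create_weights now allocates zero-initialized parameters with their final shapes and dtypes at construction time, and a new process_loaded_weights fills them in place with set_value once the checkpoint is read. A minimal sketch of the pattern (ToyMoELayer and its load method are invented for illustration, not the real FuseMOE API):

import paddle
from paddle import nn


class ToyMoELayer(nn.Layer):
    """Invented toy layer showing the two-phase create/fill flow."""

    def __init__(self, num_experts: int = 4, hidden: int = 8):
        super().__init__()
        # Phase 1 (construction): allocate the parameter with its final
        # shape and dtype, zero-initialized, before any checkpoint I/O.
        self.ffn_weight = self.create_parameter(
            shape=[num_experts, hidden, hidden],
            dtype="int8",
            default_initializer=paddle.nn.initializer.Constant(0),
        )

    def load(self, state_dict: dict):
        # Phase 2 (loading): fill the pre-created parameter in place.
        # set_value expects a tensor of matching shape and dtype, which
        # is why the shapes must be known at construction time.
        self.ffn_weight.set_value(state_dict["ffn_weight"])

Creating parameters up front means every weight's shape is known before any checkpoint I/O, which is what the new is_supported_moe_backend helper below lets loaders check for.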

File tree

10 files changed, +444 -83 lines

fastdeploy/engine/expert_service.py

Lines changed: 2 additions & 2 deletions
@@ -59,8 +59,8 @@ def __init__(self, cfg, local_data_parallel_id):
         self.cfg.disaggregate_info = None
 
         self.scheduler = cfg.scheduler_config.scheduler()
-
-        self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
+        if cfg.splitwise_role != "mixed":
+            self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
 
         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
 
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
+from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import (
+    CutlassMoEMethod,
+)
+from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
+    BlockWiseFP8MoEMethod,
+    TensorWiseFP8MoEMethod,
+    TritonWeightOnlyMoEMethod,
+)
+
+pre_create_weights_list = (CutlassMoEMethod, TensorWiseFP8MoEMethod, BlockWiseFP8MoEMethod, TritonWeightOnlyMoEMethod)
+
+
+def is_supported_moe_backend(quant_method: MoEMethodBase):
+    return isinstance(quant_method, pre_create_weights_list)

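This new helper gives weight loaders a single predicate for whether a backend pre-creates its MoE parameters and can therefore be filled after construction. A hypothetical call site; everything here except is_supported_moe_backend, create_weights, and process_loaded_weights is invented:

def load_moe_layer(layer, quant_method, state_dict):
    # Hypothetical loader logic, not taken from the FastDeploy source.
    if is_supported_moe_backend(quant_method):
        # New flow: parameters already exist (allocated by create_weights
        # at construction time), so quantize and fill them in place.
        quant_method.process_loaded_weights(layer, state_dict)
    else:
        # Legacy flow: the backend still creates its parameters directly
        # from the checkpoint tensors.
        quant_method.create_weights(layer, state_dict)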
fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 import paddle
 from paddle import nn
 
-from fastdeploy.model_executor.models.utils import set_weight_attrs
+from fastdeploy.model_executor.layers.utils import set_weight_attrs
 from fastdeploy.platforms import current_platform
 
 from ..quantization.quant_base import QuantMethodBase

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 177 additions & 16 deletions
@@ -23,7 +23,7 @@
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform
 
-from ..utils import create_and_set_parameter, get_tensor
+from ..utils import get_tensor
 from .fused_moe_backend_base import UnquantizedFusedMoEMethod
 
 if current_platform.is_cuda():
@@ -202,7 +202,10 @@ def apply_ep_decode(
         gate_out = gate(x.cast("float32"))
         # 1. Select topk experts and weights
         topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
-        expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
+        expertwise_scale = None
+        if hasattr(layer, "up_gate_proj_in_scale_all_experts"):  # only used in w4a8
+            expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
+
         # 2. EP Dispatch
         permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
             x, topk_idx, topk_weights, expertwise_scale=expertwise_scale
@@ -382,12 +385,48 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict):
             "down_proj_in_scale": down_proj_in_scale,
         }
         for name, tensor in name_tensor_map.items():
-            create_and_set_parameter(layer, name, tensor)
+            getattr(layer, name).set_value(tensor)
 
-    def create_weights(self, layer: nn.Layer, state_dict):
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         Paddle cutlass create weight process.
         """
+        self.weight_dtype = "int8"
+        self.ffn1_weight_shape = [
+            layer.num_local_experts,
+            layer.hidden_size // 2,
+            layer.moe_intermediate_size * 2,
+        ]
+        self.ffn2_weight_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size // 2,
+            layer.hidden_size,
+        ]
+        setattr(
+            layer,
+            self.added_weight_attrs[0],
+            layer.create_parameter(
+                shape=self.ffn1_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            self.added_weight_attrs[1],
+            layer.create_parameter(
+                shape=self.ffn2_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+        self.create_w4a8_scale_weights(layer, layer.weight_key_map)
+
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
+        """
+        Paddle cutlass load weight process.
+        """
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         self.check(layer, up_gate_proj_weights, down_proj_weights)
         for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
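Note the halved axes in the w4a8 shapes above: the storage dtype is int8, and each int8 byte packs two int4 values, so the packed axis is divided by 2. A quick sanity check of the shape arithmetic, with invented sizes:

# Invented sizes, for illustration only.
num_local_experts = 8
hidden_size = 4096
moe_intermediate_size = 1024

# up_gate_proj: packed along the hidden axis (two int4 values per byte).
ffn1_weight_shape = [num_local_experts, hidden_size // 2, moe_intermediate_size * 2]

# down_proj: packed along the intermediate axis.
ffn2_weight_shape = [num_local_experts, moe_intermediate_size // 2, hidden_size]

assert ffn1_weight_shape == [8, 2048, 2048]
assert ffn2_weight_shape == [8, 512, 4096]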
@@ -397,11 +436,63 @@ def create_weights(self, layer: nn.Layer, state_dict):
                 quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80)
                 weight_list.append(quant_weight)
             quanted_weight = paddle.stack(weight_list, axis=0)
-            create_and_set_parameter(layer, weight_name, quanted_weight)
+            getattr(layer, weight_name).set_value(quanted_weight)
 
-        self.create_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
+        self.load_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
 
-    def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
+    def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
+        """
+        Create empty w4a8 scale parameters, to be filled later from the state dict.
+        Args:
+            layer (nn.Layer): The layer to add parameters to.
+            weight_key_map (dict): The weight key map.
+        """
+        self.default_dtype = layer._helper.get_default_dtype()
+        if layer.ep_size > 1:
+            setattr(
+                layer,
+                "up_gate_proj_in_scale_all_experts",
+                layer.create_parameter(
+                    shape=[layer.num_experts],
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+        # in_scales
+        for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
+            setattr(
+                layer,
+                in_scale_name,
+                layer.create_parameter(
+                    shape=[layer.num_local_experts],
+                    dtype="float32",
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+
+        # weight_scales
+        setattr(
+            layer,
+            "up_gate_proj_weight_scale",
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
+                dtype=self.default_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            "down_proj_weight_scale",
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.hidden_size],
+                dtype=self.default_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+    def load_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
         """
         Get w4a8 weights from state dict and process them.
         Args:
@@ -415,7 +506,7 @@ def _extract_scale_tensor(state_dict, key_template, expert_idx):
 
         def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
             processed_in_scale = 1 / paddle.concat(in_scales)
-            create_and_set_parameter(layer, name, processed_in_scale)
+            getattr(layer, name).set_value(processed_in_scale)
             return processed_in_scale
 
         def _process_weight_scale(
@@ -426,7 +517,7 @@ def _process_weight_scale(
             processed_weight_scale = (
                 paddle.stack(weight_scales, axis=0) / (127 * 112) / processed_in_scale[:, None]
             ).cast(paddle.get_default_dtype())
-            create_and_set_parameter(layer, name, processed_weight_scale)
+            getattr(layer, name).set_value(processed_weight_scale)
 
         # 1. Init scale containers and maps
         up_gate_proj_weight_scales = []
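These two helpers capture the scale bookkeeping that the pre-created w4a8 parameters are later filled with: input scales are stored as reciprocals (1 / s), and each weight scale is divided by 127 * 112 (127 is the int8 maximum; 112 looks like a w4a8-specific calibration constant) and by the matching inverted input scale. A small numeric sketch with invented values:

import paddle

# Invented per-expert calibration values, for illustration only.
in_scales = [paddle.to_tensor([0.02]), paddle.to_tensor([0.04])]
weight_scales = [paddle.ones([4]) * 0.5, paddle.ones([4]) * 0.25]

# The stored input scale is the reciprocal of the calibration scale.
processed_in_scale = 1 / paddle.concat(in_scales)  # shape [2]

# Each expert's weight scale is normalized by 127 * 112 and by its
# inverted input scale, broadcast over the output channels.
processed_weight_scale = (
    paddle.stack(weight_scales, axis=0) / (127 * 112) / processed_in_scale[:, None]
)  # shape [2, 4]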
@@ -456,8 +547,8 @@ def _process_weight_scale(
             for expert_idx in range(layer.num_experts):
                 scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
                 up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
-            create_and_set_parameter(
-                layer, "up_gate_proj_in_scale_all_experts", paddle.concat(up_gate_proj_in_scales_all_experts)
+            getattr(layer, "up_gate_proj_in_scale_all_experts").set_value(
+                paddle.concat(up_gate_proj_in_scales_all_experts)
             )
 
         for local_expert_idx in range(layer.num_local_experts):
@@ -527,15 +618,85 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict):
             "down_proj_weight_scale": down_proj_weight_scale,
         }
         for name, tensor in name_tensor_map.items():
-            create_and_set_parameter(layer, name, tensor)
+            getattr(layer, name).set_value(tensor)
 
-    def create_weights(self, layer: nn.Layer, state_dict):
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         Paddle cutlass create weight process.
         """
+        self.default_dtype = layer._helper.get_default_dtype()
+        self.weight_dtype = "int8"
+
+        up_gate_proj_weight_name = self.added_weight_attrs[0]
+        down_proj_weight_name = self.added_weight_attrs[1]
+        if self.moe_quant_type == "weight_only_int4":
+            self.ffn1_weight_shape = [
+                layer.num_local_experts,
+                layer.moe_intermediate_size,
+                layer.hidden_size,
+            ]
+        else:
+            self.ffn1_weight_shape = [
+                layer.num_local_experts,
+                layer.moe_intermediate_size * 2,
+                layer.hidden_size,
+            ]
+        if self.moe_quant_type == "weight_only_int4":
+            self.ffn2_weight_shape = [
+                layer.num_local_experts,
+                layer.hidden_size // 2,
+                layer.moe_intermediate_size,
+            ]
+        else:
+            self.ffn2_weight_shape = [
+                layer.num_local_experts,
+                layer.hidden_size,
+                layer.moe_intermediate_size,
+            ]
+        setattr(
+            layer,
+            up_gate_proj_weight_name,
+            layer.create_parameter(
+                shape=self.ffn1_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            down_proj_weight_name,
+            layer.create_parameter(
+                shape=self.ffn2_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # weight_scale
+        setattr(
+            layer,
+            self.added_scale_attrs[0],
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
+                dtype=self.default_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            self.added_scale_attrs[1],
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.hidden_size],
+                dtype=self.default_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
+        """
+        Paddle cutlass load weight process.
+        """
         up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
         self.check(layer, up_gate_proj_weights, down_proj_weights)
-
         for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
             weight_name = self.added_weight_attrs[idx]
             scale_name = self.added_scale_attrs[idx]
@@ -547,7 +708,7 @@ def create_weights(self, layer: nn.Layer, state_dict):
                 weight_list.append(quant_weight)
                 weight_scale_list.append(scale)
             quanted_weight = paddle.stack(weight_list, axis=0)
-            create_and_set_parameter(layer, weight_name, quanted_weight)
+            getattr(layer, weight_name).set_value(quanted_weight)
 
             quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
-            create_and_set_parameter(layer, scale_name, quanted_weight_scale)
+            getattr(layer, scale_name).set_value(quanted_weight_scale)
