
Commit 03db555

[DeepSeek R1] add chunk moe args on deepseek_r1 (#1834)

Add chunked-MoE arguments for FP8 DeepSeek R1 inference; works together with intel/neural-compressor#2270. Cc @czhu15 @Wei-Lin-Intel @yiliu30 @hlin99.

1 parent 17ef2d1 · commit 03db555
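
In short: when `VLLM_SUPPORT_MOE_CHUNK` is enabled, each fused-MoE call receives a token-count-dependent `chunk_size` (plus a `total_experts` hint) chosen from two parallel comma-separated lists. A minimal configuration sketch using the defaults from the diff (the values are tunables, not requirements):

```python
# Hypothetical usage sketch: the feature is driven purely by environment
# variables, so they must be set before the MoE layers are constructed.
import os

os.environ["VLLM_SUPPORT_MOE_CHUNK"] = "true"  # opt in to chunked MoE
# Parallel lists: a batch of <= token_boundary_list[i] tokens uses
# chunk_size_list[i]; larger batches fall back to the last chunk size.
os.environ["PT_HPU_MOE_CHUNK"] = "64,128,512,1024,1536,2048,4096"
os.environ["PT_HPU_MOE_TOKEN_BOUNDARY"] = "64,64,1536,1536,2048,2048,4096"
```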

File tree: 3 files changed, +91 −1 lines

vllm/model_executor/layers/fused_moe/layer.py (1 addition, 0 deletions)

```diff
@@ -495,6 +495,7 @@ def __init__(
         experts_min, experts_max = 0, self.local_num_experts
         moe_op = VllmMixtureOfExpertsOpFP8(
             num_expert_per_group,
+            self.global_num_experts,
             experts_min + ep_shift,
             experts_max - 1 + ep_shift,
         )
```
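
Because `global_num_experts` was added as the second positional parameter of `VllmMixtureOfExpertsOpFP8.__init__` (see the `vllm_ext_patch.py` change below), this call site must pass it in the same position, otherwise the value would silently shift into `experts_min`/`experts_max`. The same call, annotated (comments are editorial):

```python
# How the updated call site lines up with the new signature:
moe_op = VllmMixtureOfExpertsOpFP8(
    num_expert_per_group,        # num_experts (experts owned by this rank/group)
    self.global_num_experts,     # NEW: global_num_experts (total across all ranks)
    experts_min + ep_shift,      # experts_min
    experts_max - 1 + ep_shift,  # experts_max
)
```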

vllm/model_executor/layers/quantization/fp8.py (36 additions, 0 deletions)
```diff
@@ -516,6 +516,29 @@ def __init__(self, quant_config: Fp8Config):
         self.enable_dmoe_dynamic_scale = os.environ.get("VLLM_DMOE_DYNAMIC_SCALE", False) in ["1", "true"]
         self.use_static_moe = os.environ.get("VLLM_USE_STATIC_MOE", "0") in ["1", "true"]
         self.optimize_with_partial_experts = os.environ.get("VLLM_OPTIMIZE_WITH_PARTIAL_EXPERTS", "0") in ["1", "true"]
+        self.enable_moe_chunk = os.environ.get('VLLM_SUPPORT_MOE_CHUNK',
+                                               'false').lower() == 'true'
+        self.chunk_size_list = [
+            int(x)
+            for x in os.environ.get(
+                "PT_HPU_MOE_CHUNK", "64,128,512,1024,1536,2048,4096"
+            ).split(",")
+            if x.strip()
+        ]
+        self.token_boundary_list = [
+            int(x)
+            for x in os.environ.get(
+                "PT_HPU_MOE_TOKEN_BOUNDARY", "64,64,1536,1536,2048,2048,4096"
+            ).split(",")
+            if x.strip()
+        ]
+        assert len(self.chunk_size_list) == len(self.token_boundary_list), (
+            f"chunk_size_list({len(self.chunk_size_list)}) and "
+            f"token_boundary_list({len(self.token_boundary_list)}) must be the same length"
+        )
+        if self.enable_moe_chunk:
+            logger.info("token_boundary_list is:%s", self.token_boundary_list)
+            logger.info("chunk_size_list is:%s", self.chunk_size_list)

     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
```
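
The two lists must stay in lockstep, which the assert enforces at startup rather than failing mid-inference. A self-contained sketch of the same parsing-and-validation pattern (the env values here are deliberately mismatched, purely for illustration):

```python
import os

# Deliberately one boundary short, to show the failure mode.
os.environ["PT_HPU_MOE_CHUNK"] = "64,128,512"
os.environ["PT_HPU_MOE_TOKEN_BOUNDARY"] = "64,1536"

chunk_size_list = [
    int(x) for x in os.environ["PT_HPU_MOE_CHUNK"].split(",") if x.strip()
]
token_boundary_list = [
    int(x) for x in os.environ["PT_HPU_MOE_TOKEN_BOUNDARY"].split(",") if x.strip()
]
assert len(chunk_size_list) == len(token_boundary_list), (
    f"chunk_size_list({len(chunk_size_list)}) and "
    f"token_boundary_list({len(token_boundary_list)}) must be the same length"
)  # -> AssertionError: chunk_size_list(3) and token_boundary_list(2) ...
```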
```diff
@@ -1043,6 +1066,17 @@ def do_dynamic_moe_with_static_scaling(x, topk_ids, topk_weights, w13_weight_fp8
                                                topk_weights_across_dp)

        batched_tokens = x.shape[0]
+       kwargs = {}
+       if self.enable_moe_chunk:
+           chunk_size = self.chunk_size_list[-1]
+           for idx, threshold in enumerate(self.token_boundary_list):
+               if batched_tokens <= threshold:
+                   chunk_size = self.chunk_size_list[idx]
+                   break
+           kwargs = {
+               "chunk_size": chunk_size,
+               "total_experts": 256,
+           }

        if batched_tokens > self.moe_slice_length:
            final_hidden_states_list = []
```
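
The boundary scan picks the first threshold that the batch fits under and falls back to the largest chunk size otherwise. A runnable sketch of the same loop with this file's defaults (`select_chunk_size` is a hypothetical helper name, not part of the commit):

```python
chunk_size_list = [64, 128, 512, 1024, 1536, 2048, 4096]
token_boundary_list = [64, 64, 1536, 1536, 2048, 2048, 4096]

def select_chunk_size(batched_tokens: int) -> int:
    chunk_size = chunk_size_list[-1]  # fallback for batches past every boundary
    for idx, threshold in enumerate(token_boundary_list):
        if batched_tokens <= threshold:
            chunk_size = chunk_size_list[idx]
            break
    return chunk_size

assert select_chunk_size(32) == 64      # 32 <= 64 at index 0
assert select_chunk_size(100) == 512    # first boundary >= 100 is 1536, index 2
assert select_chunk_size(8192) == 4096  # larger than every boundary -> fallback
```

Note that this path hard-codes `total_experts` to 256 (DeepSeek R1's routed-expert count), whereas the `vllm_ext_patch.py` path below threads the value through as `global_num_experts`.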
```diff
@@ -1066,6 +1100,7 @@ def do_dynamic_moe_with_static_scaling(x, topk_ids, topk_weights, w13_weight_fp8
                    activation="silu",
                    experts_min=ep_shift,
                    experts_max=(num_experts + ep_shift - 1),
+                   **kwargs
                )
                final_hidden_states_list.append(current_hidden_states)
            final_hidden_states = torch.cat(final_hidden_states_list, dim=0)
@@ -1084,6 +1119,7 @@ def do_dynamic_moe_with_static_scaling(x, topk_ids, topk_weights, w13_weight_fp8
                activation="silu",
                experts_min=ep_shift,
                experts_max=(num_experts + ep_shift - 1),
+               **kwargs
            )
        return final_hidden_states.view(-1, x.shape[1])
```

vllm/model_executor/layers/vllm_ext_patch.py (54 additions, 1 deletion)
```diff
@@ -1,10 +1,14 @@
 # ==-------------------------------------------------------------------------==
 # VLLM-HPU-EXT PATCH Start
 # ==-------------------------------------------------------------------------==
+import logging
+import os
 import torch
 from typing import Callable, Optional, Tuple
 import habana_frameworks.torch as htorch

+logging.basicConfig(level=logging.INFO)
+

 class MoeFP8Matmul(torch.nn.Module):
     def __init__(
```
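
One caveat worth noting: calling `logging.basicConfig(...)` at import time configures the process-wide root logger, which can override the host application's logging setup. A module-level logger is the more conventional pattern, sketched below (an editorial aside, not part of the commit):

```python
# Editorial suggestion: a module-level logger leaves the root logger alone
# and lets the embedding application decide handlers and levels.
import logging

logger = logging.getLogger(__name__)
```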
```diff
@@ -66,7 +70,11 @@ def get_dequant_weights_func(

 class VllmMixtureOfExpertsOpFP8(torch.nn.Module):
     def __init__(
-        self, num_experts: int, experts_min: int = 0, experts_max: int = 8
+        self,
+        num_experts: int,
+        global_num_experts: int = 0,
+        experts_min: int = 0,
+        experts_max: int = 8,
     ):
         super().__init__()
         self.w13_list = torch.nn.ModuleList(
```
```diff
@@ -75,10 +83,52 @@ def __init__(
         self.w2_list = torch.nn.ModuleList(
             [MoeFP8Matmul() for _ in range(num_experts)]
         )
+        self.enable_moe_chunk = (
+            os.environ.get("VLLM_SUPPORT_MOE_CHUNK", "false").lower() == "true"
+        )
+        self.chunk_size_list = [
+            int(x)
+            for x in os.environ.get(
+                "PT_HPU_MOE_CHUNK", "64,128,512,1024,1536,2048,4096"
+            ).split(",")
+            if x.strip()
+        ]
+        self.token_boundary_list = [
+            int(x)
+            for x in os.environ.get(
+                "PT_HPU_MOE_TOKEN_BOUNDARY", "64,128,1536,1736,2048,3072,4096"
+            ).split(",")
+            if x.strip()
+        ]
+        assert len(self.chunk_size_list) == len(self.token_boundary_list), (
+            f"chunk_size_list({len(self.chunk_size_list)}) and "
+            f"token_boundary_list({len(self.token_boundary_list)}) must be the same length"
+        )
+        logger = logging.getLogger()
+        if self.enable_moe_chunk:
+            logger.info("token_boundary_list is:%s", self.token_boundary_list)
+            logger.info("chunk_size_list is:%s", self.chunk_size_list)
+
         self.num_experts = num_experts
+        self.global_num_experts = global_num_experts
         self.experts_min = experts_min
         self.experts_max = experts_max

+    def _get_extra_kwargs(self, tokens_num: int):
+        if self.enable_moe_chunk:
+            chunk_size = self.chunk_size_list[-1]
+            for idx, threshold in enumerate(self.token_boundary_list):
+                if tokens_num <= threshold:
+                    chunk_size = self.chunk_size_list[idx]
+                    break
+            kwargs = {
+                "chunk_size": chunk_size,
+                "total_experts": self.global_num_experts,
+            }
+        else:
+            kwargs = {}
+        return kwargs
+
     def forward(
         self,
         x,
```
```diff
@@ -89,6 +139,8 @@ def forward(
         max_expert = self.experts_max
         w13_list_slice = []
         w2_list_slice = []
+        tokens_num, _ = x.shape
+        kwargs = self._get_extra_kwargs(tokens_num)
         for j in range(self.num_experts):
             w13_list_slice.append(self.w13_list[j].get_dequant_weight())
             w2_list_slice.append(self.w2_list[j].get_dequant_weight())
```
```diff
@@ -103,6 +155,7 @@ def forward(
             activation="silu",
             experts_min=min_expert,
             experts_max=max_expert,
+            **kwargs,
         )
         htorch.core.mark_step()
         return final_hidden_states
```
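
Because `_get_extra_kwargs` returns an empty dict when chunking is disabled, the `**kwargs` expansion leaves the fused-MoE call exactly as it was before this commit. A stub-based sketch of that behavior (the class below is a stand-in with the same fields, not the real op; the hidden size is illustrative):

```python
import torch

class ChunkedMoeStub:
    """Stand-in carrying only the fields _get_extra_kwargs reads."""
    def __init__(self, enable: bool):
        self.enable_moe_chunk = enable
        self.chunk_size_list = [64, 128, 512, 1024, 1536, 2048, 4096]
        self.token_boundary_list = [64, 128, 1536, 1736, 2048, 3072, 4096]
        self.global_num_experts = 256

    def _get_extra_kwargs(self, tokens_num: int):
        # Body mirrors the diff above.
        if self.enable_moe_chunk:
            chunk_size = self.chunk_size_list[-1]
            for idx, threshold in enumerate(self.token_boundary_list):
                if tokens_num <= threshold:
                    chunk_size = self.chunk_size_list[idx]
                    break
            kwargs = {
                "chunk_size": chunk_size,
                "total_experts": self.global_num_experts,
            }
        else:
            kwargs = {}
        return kwargs

x = torch.randn(200, 7168)  # 200 tokens; hidden size is illustrative
tokens_num, _ = x.shape
assert ChunkedMoeStub(False)._get_extra_kwargs(tokens_num) == {}
assert ChunkedMoeStub(True)._get_extra_kwargs(tokens_num) == {
    "chunk_size": 512,      # 200 <= 1536 at index 2
    "total_experts": 256,
}
```

One inconsistency worth flagging: the default `PT_HPU_MOE_TOKEN_BOUNDARY` differs between the two files ("64,64,1536,1536,2048,2048,4096" in fp8.py vs "64,128,1536,1736,2048,3072,4096" here), so setting the variable explicitly is the safest way to get consistent behavior.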
