Skip to content

Commit 89da2d2

Browse files
author
wanghao7
committed
improve ep_gather
1 parent 631b6a8 commit 89da2d2

File tree

1 file changed

+38
-32
lines changed

1 file changed

+38
-32
lines changed

lightllm/common/fused_moe/deepep_scatter_gather.py

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -169,78 +169,84 @@ def ep_scatter(
169169

170170
@triton.jit
def _fwd_kernel_ep_gather(
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    expert_start_loc,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_col: tl.constexpr,
    BLOCK_D: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
):
    """Weighted gather of expert outputs back into per-token rows.

    One program handles one output token: for each of its topk_col expert
    slots, load the corresponding row of input_tensor, scale it by the slot's
    recv_topk_weight, accumulate in fp32, and write the sum to output_tensor.
    Slots whose expert id is negative are skipped (token not routed there).
    """
    token_id = tl.program_id(0)
    offset = tl.arange(0, HIDDEN_SIZE_PAD)
    # HIDDEN_SIZE_PAD is HIDDEN_SIZE rounded up to a power of two; mask off
    # the padding lanes so loads/stores never cross the end of a row.
    mask = offset < HIDDEN_SIZE
    accumulator = tl.zeros([HIDDEN_SIZE_PAD], dtype=tl.float32)

    for start_topk in range(0, topk_col):
        cur_expert = tl.load(recv_topk + token_id * recv_topk_stride0 + start_topk)
        if cur_expert >= 0:
            # NOTE(review): assumes expert_start_loc/input_index hold ROW
            # indices into input_tensor (as in the pre-refactor kernel, which
            # multiplied its row index by stride0) — confirm against ep_scatter.
            start_ = tl.load(expert_start_loc + cur_expert)
            dst = tl.load(input_index + token_id * input_index_stride0 + start_topk) + start_

            weight = tl.load(recv_topk_weight + token_id * recv_topk_weight_stride0 + start_topk)
            # BUGFIX: scale the row index by the row stride and apply `mask`.
            # The previous code loaded HIDDEN_SIZE_PAD unmasked elements at an
            # unscaled offset: wrong rows, plus out-of-bounds reads whenever
            # HIDDEN_SIZE is not a power of two.
            tmp = tl.load(input_tensor + dst * input_tensor_stride0 + offset, mask=mask, other=0.0)
            accumulator += tmp.to(tl.float32) * weight

    # BUGFIX: masked store — an unmasked store writes
    # HIDDEN_SIZE_PAD - HIDDEN_SIZE elements past the end of the output row.
    tl.store(
        output_tensor + token_id * output_tensor_stride0 + offset,
        accumulator.to(output_tensor.dtype.element_ty),
        mask=mask,
    )
210212

211213

212214
@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
):
    """Combine per-expert MoE outputs into per-token rows (weighted sum).

    Args:
        input_tensor: expert outputs, shape (num_expert_tokens, hidden_size).
        recv_topk: (num_tokens, topk) expert ids; negative means "not routed".
        recv_topk_weight: (num_tokens, topk) combine weights.
        input_index: (num_tokens, topk) per-slot row index component; added to
            expert_start_loc[expert] inside the kernel — presumably the row
            offset within each expert's group (confirm against ep_scatter).
        expert_start_loc: per-expert start location used by the kernel.
        output_tensor: (num_tokens, hidden_size) destination, overwritten.
    """
    BLOCK_D = 128  # block size of quantization
    num_warps = 4
    num_tokens = output_tensor.shape[0]
    hidden_size = input_tensor.shape[1]
    # BUGFIX: the kernel processes exactly one token per program and has no
    # grid-stride loop, so capping the grid at 65535 silently left tokens
    # beyond index 65534 unwritten. CUDA allows up to 2**31 - 1 programs on
    # grid axis 0, so launch one program per token.
    grid = (num_tokens,)

    _fwd_kernel_ep_gather[grid](
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        expert_start_loc,
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_col=recv_topk.shape[1],
        num_warps=num_warps,
        BLOCK_D=BLOCK_D,
        HIDDEN_SIZE=hidden_size,
        # pad to a power of two so tl.arange inside the kernel is legal
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
    )
    return

0 commit comments

Comments
 (0)