Skip to content

Commit e2b029c

Browse files
committed
update gather
1 parent 631b6a8 commit e2b029c

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

lightllm/common/fused_moe/deepep_scatter_gather.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -185,27 +185,31 @@ def _fwd_kernel_ep_gather(
185185
output_tensor,
186186
output_tensor_stride0,
187187
output_tensor_stride1,
188+
HIDDEN_SIZE: tl.constexpr,
189+
HIDDEN_SIZE_PAD: tl.constexpr,
188190
topk_num: tl.constexpr,
189191
BLOCK_D: tl.constexpr,
190192
):
191-
cur_block = tl.program_id(0)
192-
start_cur_token = tl.program_id(1)
193-
grid_num = tl.num_programs(1)
193+
start_cur_token = tl.program_id(0)
194+
grid_num = tl.num_programs(0)
195+
196+
offset_d = tl.arange(0, HIDDEN_SIZE_PAD)
197+
mask = offset_d < HIDDEN_SIZE
194198

195199
for cur_token in range(start_cur_token, total_token_num, grid_num):
196-
off_d = tl.arange(0, BLOCK_D)
197-
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
200+
accumulator = tl.zeros([HIDDEN_SIZE_PAD], dtype=tl.float32)
198201
for topk_index in range(0, topk_num):
199202
expert_id = tl.load(recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index)
200203
if expert_id >= 0:
201204
source_token_index = tl.load(input_index + cur_token * input_index_stride0 + topk_index)
202205
acc_weight = tl.load(recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index)
203-
tmp = tl.load(input_tensor + source_token_index * input_tensor_stride0 + cur_block * BLOCK_D + off_d)
206+
tmp = tl.load(input_tensor + source_token_index * input_tensor_stride0 + offset_d, mask=mask)
204207
accumulator += tmp.to(tl.float32) * acc_weight
205208

206209
tl.store(
207-
output_tensor + cur_token * output_tensor_stride0 + cur_block * BLOCK_D + off_d,
210+
output_tensor + cur_token * output_tensor_stride0 + offset_d,
208211
accumulator.to(output_tensor.dtype.element_ty),
212+
mask=mask,
209213
)
210214

211215

@@ -221,7 +225,7 @@ def ep_gather(
221225
num_warps = 4
222226
num_tokens = output_tensor.shape[0]
223227
hidden_size = input_tensor.shape[1]
224-
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
228+
grid = (min(num_tokens, 32768),)
225229
_fwd_kernel_ep_gather[grid](
226230
num_tokens,
227231
input_tensor,
@@ -240,6 +244,8 @@ def ep_gather(
240244
output_tensor.stride(0),
241245
output_tensor.stride(1),
242246
topk_num=recv_topk_ids.shape[1],
247+
HIDDEN_SIZE=hidden_size,
248+
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
243249
num_warps=num_warps,
244250
BLOCK_D=BLOCK_D,
245251
)

0 commit comments

Comments
 (0)