Commit 5a39dec

fix
1 parent 45a2930 commit 5a39dec

1 file changed (+19 -12 lines changed)

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 19 additions & 12 deletions
@@ -462,6 +462,7 @@ def grouped_matmul(
     out: torch.Tensor,
     mul_routed_weight: bool,
     use_fp8_w8a8: bool,
+    alloc_tensor_func=torch.empty,
     reused_mblock_infos=None,
     run_config: Optional[dict] = None,
 ):
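The new alloc_tensor_func parameter defaults to torch.empty, so existing callers are unchanged; a caller can instead route the quantization workspaces (and, further down, the intermediate caches) through a pooling allocator. Below is a minimal sketch of such a caller, assuming a hypothetical CacheTensorPool helper that is not part of this commit; only the torch.empty-compatible call signature matters.

import torch

# Hypothetical pooled allocator (illustration only, not part of this commit).
# It keeps buffers keyed by (shape, dtype, device) and hands them back instead
# of allocating fresh memory on every call. Its alloc() must accept the same
# (shape, dtype=..., device=...) arguments that torch.empty receives here.
class CacheTensorPool:
    def __init__(self):
        self._pool = {}

    def alloc(self, shape, dtype=None, device=None):
        key = (tuple(shape), dtype, str(device))
        buf = self._pool.get(key)
        if buf is None:
            buf = torch.empty(shape, dtype=dtype, device=device)
            self._pool[key] = buf
        return buf


pool = CacheTensorPool()
# fused_experts_impl(..., alloc_tensor_func=pool.alloc)  # or pass it to grouped_matmul directly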
@@ -525,8 +526,8 @@ def grouped_matmul(
         else:
             _m, _k = token_inputs.shape
             assert _k % block_size_k == 0
-            input_scale = torch.empty((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
-            qinput_tensor = torch.empty((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
+            input_scale = alloc_tensor_func((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
+            qinput_tensor = alloc_tensor_func((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
             per_token_group_quant_fp8(token_inputs, block_size_k, qinput_tensor, input_scale)
             token_inputs, token_input_scale = qinput_tensor, input_scale
 
@@ -611,6 +612,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    alloc_tensor_func=torch.empty,
     run_config: Optional[dict] = None,
 ):
     # Check constraints.
@@ -625,26 +627,29 @@ def fused_experts_impl(
     CHUNK_SIZE = FFN_MOE_CHUNK_SIZE
     topk_num = topk_ids.shape[1]
     M = min(num_tokens, CHUNK_SIZE)
-
-    cache = torch.empty(M * topk_num * max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache1 = cache[: M * topk_num * N].view(M, topk_num, N)
-    intermediate_cache2 = torch.empty((M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache3 = cache[: M * topk_num * w2.shape[1]].view(M, topk_num, w2.shape[1])
+
+    intermediate_cache13_shared = alloc_tensor_func((M, topk_num, max(N, w2.shape[1])), device=hidden_states.device, dtype=hidden_states.dtype)
+    intermediate_cache1 = intermediate_cache13_shared.view(-1)[:(M * topk_num * N)].view(M, topk_num, N)
+    intermediate_cache2 = alloc_tensor_func(
+        (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
+    )
+    intermediate_cache3 = intermediate_cache13_shared.view(-1)[:(M * topk_num * w2.shape[1])].view(M, topk_num, w2.shape[1])
 
     if inplace:
         out_hidden_states = hidden_states
     else:
-        out_hidden_states = torch.empty(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
+        out_hidden_states = alloc_tensor_func(
+            hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype
+        )
 
     for chunk in range(triton.cdiv(num_tokens, CHUNK_SIZE)):
         begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens))
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
         tokens_in_chunk, _ = curr_hidden_states.shape
 
-        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
-            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
-            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
-            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+        intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+        intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+        intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
 
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
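As in the old code, intermediate_cache1 and intermediate_cache3 share one backing buffer sized max(N, w2.shape[1]) per (token, expert) slot; the first matmul output (cache1) is fully consumed by the activation step before the second matmul writes cache3, so the aliasing is safe. The new code simply obtains that buffer from alloc_tensor_func with an explicit 3-D shape. A small illustration of the aliasing, with made-up shapes:

import torch

# Made-up sizes, for illustration only.
M, topk_num, N, w2_out = 4, 2, 16, 8

shared = torch.empty((M, topk_num, max(N, w2_out)))
cache1 = shared.view(-1)[: M * topk_num * N].view(M, topk_num, N)
cache3 = shared.view(-1)[: M * topk_num * w2_out].view(M, topk_num, w2_out)

# Both views start at the same storage, so cache3 costs no extra memory;
# writing cache3 overwrites cache1, which has already been consumed.
assert cache1.data_ptr() == shared.data_ptr() == cache3.data_ptr()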
@@ -668,6 +673,7 @@ def fused_experts_impl(
             out=intermediate_cache1.view(-1, N),
             mul_routed_weight=False,
             use_fp8_w8a8=use_fp8_w8a8,
+            alloc_tensor_func=alloc_tensor_func,
             run_config=run_config,
         )
 
@@ -686,6 +692,7 @@ def fused_experts_impl(
             out=intermediate_cache3.view(-1, w2.shape[1]),
             mul_routed_weight=True,
             use_fp8_w8a8=use_fp8_w8a8,
+            alloc_tensor_func=alloc_tensor_func,
             reused_mblock_infos=reused_mblock_infos,
             run_config=run_config,
         )
