Commit 2926532

Torch.ops (#1001)
1 parent 762ae34 · commit 2926532

File tree

4 files changed: 230 additions & 35 deletions


lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 2 additions & 2 deletions
@@ -76,9 +76,9 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
         w2, w2_scale = self.w2
         use_fp8_w8a8 = self.quant_method is not None

-        from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl
+        from lightllm.common.fused_moe.grouped_fused_moe import fused_experts

-        fused_experts_impl(
+        fused_experts(
             hidden_states=input_tensor,
             w1=w1,
             w2=w2,

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 166 additions & 24 deletions
@@ -34,8 +34,9 @@
 from .moe_silu_and_mul import silu_and_mul_fwd
 from .moe_sum_reduce import moe_sum_reduce
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
+from lightllm.utils.torch_ops_utils import direct_register_custom_op

-FFN_MOE_CHUNK_SIZE = 8 * 1024
+FFN_MOE_CHUNK_SIZE = 32 * 1024

 logger = init_logger(__name__)

@@ -355,7 +356,7 @@ def grouped_matmul_kernel(
     tile_n_idx = pid_n

     # get the gemm size of the current problem
-    cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last")
+    cur_m = tl.load(expert_to_token_num + expert_id)

     # do regular gemm here
     offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -461,9 +462,8 @@ def grouped_matmul(
     out: torch.Tensor,
     mul_routed_weight: bool,
     use_fp8_w8a8: bool,
-    alloc_tensor_func=torch.empty,
     reused_mblock_infos=None,
-    **run_config,
+    run_config: Optional[dict] = None,
 ):
     """
     token_num_mul_topk_num is int equal token_num * topk_num,
@@ -493,7 +493,8 @@ def grouped_matmul(
     if expert_to_weights_scale.ndim == 3:
         block_size_n = expert_weights.shape[1] // expert_to_weights_scale.shape[1]
         block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2]
-    if not run_config:
+
+    if run_config is None:
         run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
             M=token_inputs.shape[0],
             N=n,
@@ -525,8 +526,8 @@ def grouped_matmul(
     else:
         _m, _k = token_inputs.shape
         assert _k % block_size_k == 0
-        input_scale = alloc_tensor_func((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
-        qinput_tensor = alloc_tensor_func((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
+        input_scale = torch.empty((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
+        qinput_tensor = torch.empty((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
         per_token_group_quant_fp8(token_inputs, block_size_k, qinput_tensor, input_scale)
         token_inputs, token_input_scale = qinput_tensor, input_scale

@@ -611,8 +612,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
-    alloc_tensor_func=torch.empty,
-    **run_config,
+    run_config: Optional[dict] = None,
 ):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
@@ -627,18 +627,16 @@ def fused_experts_impl(
     topk_num = topk_ids.shape[1]
     M = min(num_tokens, CHUNK_SIZE)

-    intermediate_cache1 = alloc_tensor_func((M, topk_num, N), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache2 = alloc_tensor_func(
-        (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
-    )
-    intermediate_cache3 = alloc_tensor_func(
-        (M, topk_num, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype
-    )
+    cache = torch.empty(M*topk_num*max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype)
+    intermediate_cache1 = cache[:M * topk_num * N].view(M, topk_num, N)
+    intermediate_cache2 = torch.empty((M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
+    intermediate_cache3 = cache[:M * topk_num * w2.shape[1]].view(M, topk_num, w2.shape[1])
+

     if inplace:
         out_hidden_states = hidden_states
     else:
-        out_hidden_states = alloc_tensor_func(
+        out_hidden_states = torch.empty(
             hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype
         )

@@ -647,9 +645,10 @@ def fused_experts_impl(
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
         tokens_in_chunk, _ = curr_hidden_states.shape

-        intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
-        intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
-        intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]

         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
@@ -673,8 +672,7 @@ def fused_experts_impl(
             out=intermediate_cache1.view(-1, N),
             mul_routed_weight=False,
             use_fp8_w8a8=use_fp8_w8a8,
-            alloc_tensor_func=alloc_tensor_func,
-            **run_config,
+            run_config=run_config,
         )

         silu_and_mul_fwd(intermediate_cache1.view(-1, N), intermediate_cache2.view(-1, N // 2))
@@ -692,12 +690,156 @@ def fused_experts_impl(
             out=intermediate_cache3.view(-1, w2.shape[1]),
             mul_routed_weight=True,
             use_fp8_w8a8=use_fp8_w8a8,
-            alloc_tensor_func=alloc_tensor_func,
             reused_mblock_infos=reused_mblock_infos,
-            **run_config,
+            run_config=run_config,
         )

         moe_sum_reduce(
             intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]
         )
     return out_hidden_states
+
+
+def inplace_fused_experts_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> None:
+    fused_experts_impl(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        True,
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        w1_scale,
+        w2_scale,
+        a1_scale,
+        a2_scale,
+    )
+
+def inplace_fused_experts_impl_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> None:
+    pass
+
+direct_register_custom_op(
+    "inplace_fused_experts_impl",
+    inplace_fused_experts_impl,
+    ["hidden_states"],
+    inplace_fused_experts_impl_fake,
+)
+
+def outplace_fused_experts_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> None:
+    return fused_experts_impl(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        False,
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        w1_scale,
+        w2_scale,
+        a1_scale,
+        a2_scale,
+    )
+
+def outplace_fused_experts_impl_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> None:
+    return torch.empty_like(hidden_states)
+
+direct_register_custom_op(
+    "outplace_fused_experts_impl",
+    outplace_fused_experts_impl,
+    [],
+    outplace_fused_experts_impl_fake,
+)
+
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
+    if inplace:
+        torch.ops.lightllm.inplace_fused_experts_impl(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            use_fp8_w8a8,
+            use_int8_w8a16,
+            w1_scale,
+            w2_scale,
+            a1_scale,
+            a2_scale,
+        )
+        return hidden_states
+    else:
+        return torch.ops.lightllm.outplace_fused_experts_impl(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            use_fp8_w8a8,
+            use_int8_w8a16,
+            w1_scale,
+            w2_scale,
+            a1_scale,
+            a2_scale,
+        )
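
Two details in the diff above are worth calling out. First, fused_experts_impl now carves intermediate_cache1 (the w1 output) and intermediate_cache3 (the w2 output) out of one flat buffer: the w1 output is fully consumed by silu_and_mul_fwd before the w2 output is written, so the two views are never live at the same time and only intermediate_cache2 needs its own storage. A minimal sketch of that aliasing pattern, with illustrative sizes (the variable names mirror those in fused_experts_impl):

# Illustrative sizes only; requires a CUDA device.
import torch

M, topk_num, N, w2_out_dim = 16, 4, 1024, 512
device, dtype = "cuda", torch.bfloat16

# One flat buffer sized for the larger of the two intermediates.
cache = torch.empty(M * topk_num * max(N, w2_out_dim), device=device, dtype=dtype)

# Both views alias the same storage; they are used in disjoint phases.
intermediate_cache1 = cache[: M * topk_num * N].view(M, topk_num, N)                    # w1 output
intermediate_cache3 = cache[: M * topk_num * w2_out_dim].view(M, topk_num, w2_out_dim)  # w2 output

# The SiLU-and-mul result is read while intermediate_cache3 is being written,
# so it cannot share the buffer and keeps its own allocation.
intermediate_cache2 = torch.empty((M, topk_num, N // 2), device=device, dtype=dtype)

assert intermediate_cache1.data_ptr() == intermediate_cache3.data_ptr()

Second, the new public fused_experts wrapper dispatches through torch.ops.lightllm.inplace_fused_experts_impl / outplace_fused_experts_impl, each registered together with a fake (meta) implementation, which lets tracers such as torch.compile see the whole expert computation as a single custom op instead of a deep Python call tree.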

lightllm/utils/torch_ops_utils.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils.py
+from torch.library import Library
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
+import torch
+
+lightllm_lib = Library("lightllm", "FRAGMENT")  # noqa
+
+
+# Some backends use pytorch version < 2.4.0 which doesn't
+# support `torch.library.custom_op`.
+def supports_custom_op() -> bool:
+    return hasattr(torch.library, "custom_op")
+
+
+def direct_register_custom_op(
+    op_name: str,
+    op_func: Callable,
+    mutates_args: List[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+):
+    """
+    `torch.library.custom_op` can have significant overhead because it
+    needs to consider complicated dispatching logic. This function
+    directly registers a custom op and dispatches it to the CUDA backend.
+    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
+    for more details.
+    """
+    import torch.library
+
+    if hasattr(torch.library, "infer_schema"):
+        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
+    else:
+        # for pytorch 2.4
+        import torch._custom_op.impl
+
+        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
+
+    my_lib = target_lib or lightllm_lib
+    my_lib.define(op_name + schema_str)
+    my_lib.impl(op_name, op_func, "CUDA")
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
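
For reference, a minimal usage sketch of the helper above. The op (scale_inplace), its implementation, and the tensor are hypothetical; only direct_register_custom_op and the lightllm op namespace come from this commit. It assumes a CUDA build of PyTorch recent enough to provide either torch.library.infer_schema or the 2.4-era fallback.

import torch
from lightllm.utils.torch_ops_utils import direct_register_custom_op

# Hypothetical op that scales a tensor in place. Type annotations are required so
# the schema can be inferred; any mutated argument must be listed in mutates_args.
def scale_inplace(x: torch.Tensor, factor: float) -> None:
    x.mul_(factor)

# Fake (meta) implementation used during tracing; an in-place op that returns
# nothing needs no work here.
def scale_inplace_fake(x: torch.Tensor, factor: float) -> None:
    pass

direct_register_custom_op("scale_inplace", scale_inplace, ["x"], scale_inplace_fake)

t = torch.ones(4, device="cuda")
torch.ops.lightllm.scale_inplace(t, 3.0)  # dispatched through the registered CUDA impl
print(t)  # tensor([3., 3., 3., 3.], device='cuda:0')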
