Skip to content

Commit 6430851

Browse files
committed
reformat
1 parent e5dbfb5 commit 6430851

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -493,7 +493,7 @@ def grouped_matmul(
493493
if expert_to_weights_scale.ndim == 3:
494494
block_size_n = expert_weights.shape[1] // expert_to_weights_scale.shape[1]
495495
block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2]
496-
496+
497497
if run_config is None:
498498
run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
499499
M=token_inputs.shape[0],
@@ -626,18 +626,15 @@ def fused_experts_impl(
626626
topk_num = topk_ids.shape[1]
627627
M = min(num_tokens, CHUNK_SIZE)
628628

629-
cache = torch.empty(M*topk_num*max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype)
630-
intermediate_cache1 = cache[:M * topk_num * N].view(M, topk_num, N)
629+
cache = torch.empty(M * topk_num * max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype)
630+
intermediate_cache1 = cache[: M * topk_num * N].view(M, topk_num, N)
631631
intermediate_cache2 = torch.empty((M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
632-
intermediate_cache3 = cache[:M * topk_num * w2.shape[1]].view(M, topk_num, w2.shape[1])
633-
632+
intermediate_cache3 = cache[: M * topk_num * w2.shape[1]].view(M, topk_num, w2.shape[1])
634633

635634
if inplace:
636635
out_hidden_states = hidden_states
637636
else:
638-
out_hidden_states = torch.empty(
639-
hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype
640-
)
637+
out_hidden_states = torch.empty(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
641638

642639
for chunk in range(triton.cdiv(num_tokens, CHUNK_SIZE)):
643640
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens))
@@ -711,7 +708,7 @@ def inplace_fused_experts_impl(
711708
w2_scale: Optional[torch.Tensor] = None,
712709
a1_scale: Optional[torch.Tensor] = None,
713710
a2_scale: Optional[torch.Tensor] = None,
714-
)-> None:
711+
) -> None:
715712
fused_experts_impl(
716713
hidden_states,
717714
w1,
@@ -727,6 +724,7 @@ def inplace_fused_experts_impl(
727724
a2_scale,
728725
)
729726

727+
730728
def inplace_fused_experts_impl_fake(
731729
hidden_states: torch.Tensor,
732730
w1: torch.Tensor,
@@ -739,16 +737,18 @@ def inplace_fused_experts_impl_fake(
739737
w2_scale: Optional[torch.Tensor] = None,
740738
a1_scale: Optional[torch.Tensor] = None,
741739
a2_scale: Optional[torch.Tensor] = None,
742-
)-> None:
740+
) -> None:
743741
pass
744742

743+
745744
direct_register_custom_op(
746745
"inplace_fused_experts_impl",
747746
inplace_fused_experts_impl,
748747
["hidden_states"],
749748
inplace_fused_experts_impl_fake,
750749
)
751750

751+
752752
def outplace_fused_experts_impl(
753753
hidden_states: torch.Tensor,
754754
w1: torch.Tensor,
@@ -761,7 +761,7 @@ def outplace_fused_experts_impl(
761761
w2_scale: Optional[torch.Tensor] = None,
762762
a1_scale: Optional[torch.Tensor] = None,
763763
a2_scale: Optional[torch.Tensor] = None,
764-
)-> None:
764+
) -> None:
765765
return fused_experts_impl(
766766
hidden_states,
767767
w1,
@@ -777,6 +777,7 @@ def outplace_fused_experts_impl(
777777
a2_scale,
778778
)
779779

780+
780781
def outplace_fused_experts_impl_fake(
781782
hidden_states: torch.Tensor,
782783
w1: torch.Tensor,
@@ -789,16 +790,18 @@ def outplace_fused_experts_impl_fake(
789790
w2_scale: Optional[torch.Tensor] = None,
790791
a1_scale: Optional[torch.Tensor] = None,
791792
a2_scale: Optional[torch.Tensor] = None,
792-
)-> None:
793+
) -> None:
793794
return torch.empty_like(hidden_states)
794795

796+
795797
direct_register_custom_op(
796798
"outplace_fused_experts_impl",
797799
outplace_fused_experts_impl,
798800
[],
799801
outplace_fused_experts_impl_fake,
800802
)
801803

804+
802805
def fused_experts(
803806
hidden_states: torch.Tensor,
804807
w1: torch.Tensor,
@@ -825,7 +828,7 @@ def fused_experts(
825828
w1_scale,
826829
w2_scale,
827830
a1_scale,
828-
a2_scale,
831+
a2_scale,
829832
)
830833
return hidden_states
831834
else:

lightllm/utils/torch_ops_utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/utils.py
22
from torch.library import Library
3+
34
from typing import (
45
Any,
56
Callable,
@@ -52,4 +53,4 @@ def direct_register_custom_op(
5253
my_lib.define(op_name + schema_str)
5354
my_lib.impl(op_name, op_func, "CUDA")
5455
if fake_impl is not None:
55-
my_lib._register_fake(op_name, fake_impl)
56+
my_lib._register_fake(op_name, fake_impl)

test/kernel/fuse_moe_tuning.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -273,11 +273,7 @@ def get_test_configs(split_id, split_count):
273273
4,
274274
8,
275275
]:
276-
for BLOCK_SIZE_M in [
277-
32,
278-
64,
279-
128
280-
]:
276+
for BLOCK_SIZE_M in [32, 64, 128]:
281277
for BLOCK_SIZE_N in [32, 64, 128]:
282278
for BLOCK_SIZE_K in [32, 64, 128]:
283279
t_config = {

0 commit comments

Comments
 (0)