Commit a9e0156

reformat
1 parent 161ae03 commit a9e0156

8 files changed: +43 additions, -28 deletions


lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 19 additions & 12 deletions
@@ -70,20 +70,27 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
         )
         topk_weights.mul_(self.routed_scaling_factor)
         if self.num_fused_shared_experts > 0:
-            pad_topk_ids = torch.arange(
-                start=self.n_routed_experts - self.num_fused_shared_experts,
-                end=self.n_routed_experts,
-                step=1,
-                dtype=topk_ids.dtype,
-                device="cuda").view(1, self.num_fused_shared_experts).repeat(topk_ids.shape[0], 1)
-            pad_topk_weights = torch.full((topk_weights.shape[0], self.num_fused_shared_experts),
-                                          fill_value=1.0,
-                                          device="cuda",
-                                          dtype=topk_weights.dtype)
-
+            pad_topk_ids = (
+                torch.arange(
+                    start=self.n_routed_experts - self.num_fused_shared_experts,
+                    end=self.n_routed_experts,
+                    step=1,
+                    dtype=topk_ids.dtype,
+                    device="cuda",
+                )
+                .view(1, self.num_fused_shared_experts)
+                .repeat(topk_ids.shape[0], 1)
+            )
+            pad_topk_weights = torch.full(
+                (topk_weights.shape[0], self.num_fused_shared_experts),
+                fill_value=1.0,
+                device="cuda",
+                dtype=topk_weights.dtype,
+            )
+
             topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
             topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1)
-
+
         w1, w1_scale = self.w1
         w2, w2_scale = self.w2
         use_fp8_w8a8 = self.quant_method is not None
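
The reformatted block above is behavior-preserving. For reference, a minimal CPU-only sketch of the padding pattern it implements (toy sizes and random tensors, not LightLLM's real configuration; device="cuda" dropped so it runs anywhere):

import torch

# toy sizes, purely illustrative
num_tokens = 3
top_k = 2
n_routed_experts = 8              # shared experts are assumed to sit at the tail of the id range
num_fused_shared_experts = 1

topk_ids = torch.randint(0, n_routed_experts - num_fused_shared_experts, (num_tokens, top_k))
topk_weights = torch.rand(num_tokens, top_k)

# append the shared-expert ids to every token's top-k list, each with weight 1.0
pad_topk_ids = (
    torch.arange(n_routed_experts - num_fused_shared_experts, n_routed_experts, dtype=topk_ids.dtype)
    .view(1, num_fused_shared_experts)
    .repeat(num_tokens, 1)
)
pad_topk_weights = torch.full((num_tokens, num_fused_shared_experts), 1.0, dtype=topk_weights.dtype)

topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)        # (num_tokens, top_k + num_fused_shared_experts)
topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1)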

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 3 additions & 2 deletions
@@ -4,6 +4,7 @@
 import triton.language as tl
 from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
 
+
 @triton.jit
 def _silu_and_mul_kernel_fast(
     input_ptr,
@@ -34,12 +35,12 @@ def _silu_and_mul_kernel_fast(
     else:
         mask = None
         other = None
-
+
     for m_index in tl.range(m_start_index, m_end_index, num_stages=NUM_STAGES):
         gate_offsets = m_index * stride_input_m + n_offsets[None, :]
         up_offsets = m_index * stride_input_m + (n_offsets[None, :] + size_n)
         out_offsets = m_index * stride_output_m + n_offsets[None, :]
-
+
         up = tl.load(
             input_ptr + up_offsets,
             mask=mask,

lightllm/common/quantization/deepgemm_quant.py

Lines changed: 1 addition & 1 deletion
@@ -65,6 +65,6 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         )
 
         if out is None:
-            out = alloc_func((m, n), input_tensor.dtype, device=input_tensor.device)
+            out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
         deep_gemm.gemm_fp8_fp8_bf16_nt([qinput_tensor, input_scale], [qweight.t(), weight_scale.t()], out)
         return out
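
The only change here (and in the two quantization files below) is that dtype is now passed to alloc_func by keyword. Assuming alloc_func can be bound to different allocators (for example torch.empty or a caching allocator), the keyword form avoids relying on each allocator's positional argument order; torch.empty in particular accepts dtype only as a keyword argument, as this small sketch shows:

import torch

x = torch.randn(2, 3)
m, n = 4, 8

out = torch.empty((m, n), dtype=x.dtype, device=x.device)   # keyword form: fine
# torch.empty((m, n), x.dtype, device=x.device)             # raises TypeError: dtype is keyword-only here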

lightllm/common/quantization/triton_quant/triton_quant.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         m = input_tensor.shape[0]
         n = qweight.shape[1]
         if out is None:
-            out = alloc_func((m, n), input_tensor.dtype, device=input_tensor.device)
+            out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
         w8a8_block_fp8_matmul(
             input_tensor_q,
             qweight,

lightllm/common/quantization/w8a8_quant.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
             input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func
         )
         if out is None:
-            out = alloc_func((m, n), input_tensor.dtype, device=input_tensor.device)
+            out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
         if n % 128 != 0:
             w8a8_block_fp8_matmul(
                 qinput_tensor,

lightllm/models/deepseek2/triton_kernel/rotary_emb_config.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def try_to_get_best_config(
             config = {"BLOCK_SEQ": 1, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1, "HEAD_PARALLEL_NUM": 1}
         else:
             config = {"BLOCK_SEQ": 16, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1, "HEAD_PARALLEL_NUM": 1}
-
+
         return config
 
     @classmethod

test/kernel/deepseekv3_rotary_emb_tuning.py

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ def tuning_configs(
 if __name__ == "__main__":
     torch.multiprocessing.set_start_method("spawn")
     from lightllm.utils.tuning_utils import mp_tuning
-
+
     # for deepseekv3 600B
 
     for q_head_num in [128, 64, 32, 16, 8]:

test/kernel/fuse_moe_tuning.py

Lines changed: 16 additions & 9 deletions
@@ -99,20 +99,27 @@ def test_kernel(
     topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1)
     if num_fused_shared_experts > 0:
         # when fused shared experts are present, the shared experts' ids need to be padded into topk_ids
-        pad_topk_ids = torch.arange(
-            start=expert_num - num_fused_shared_experts,
-            end=expert_num,
-            step=1,
-            dtype=topk_ids.dtype,
-            device="cuda").view(1, num_fused_shared_experts).repeat(topk_ids.shape[0], 1)
+        pad_topk_ids = (
+            torch.arange(
+                start=expert_num - num_fused_shared_experts, end=expert_num, step=1, dtype=topk_ids.dtype, device="cuda"
+            )
+            .view(1, num_fused_shared_experts)
+            .repeat(topk_ids.shape[0], 1)
+        )
         topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
         topk_weights = torch.randn((m, topk + num_fused_shared_experts), device="cuda", dtype=dtype) / 10
 
-    expert_to_tokens = torch.empty((expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.int32, device="cuda")
-    expert_to_weights = torch.empty((expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.float32, device="cuda")
+    expert_to_tokens = torch.empty(
+        (expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.int32, device="cuda"
+    )
+    expert_to_weights = torch.empty(
+        (expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.float32, device="cuda"
+    )
     moe_align(topk_ids=topk_ids, out=expert_to_tokens)
     expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda")
-    moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_shared_experts)
+    moe_align1(
+        expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_shared_experts
+    )
 
     out1 = torch.zeros((m * (topk + num_fused_shared_experts), 2 * n), dtype=torch.bfloat16, device="cuda")
     down_in = torch.zeros((m * (topk + num_fused_shared_experts), n), dtype=torch.bfloat16, device="cuda")
