Commit b6ae861

run linter
Signed-off-by: Bill Nell <[email protected]>
1 parent: 918e62b

14 files changed: +471 -1756 lines changed

tests/kernels/moe/test_batched_moe.py

Lines changed: 76 additions & 58 deletions
@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import torch
-import triton
-import triton.language as tl
+from dataclasses import dataclass
 
 import pytest
-from dataclasses import dataclass
+import torch
+import triton.language as tl
 
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    invoke_moe_batched_triton_kernel,
-    invoke_batched_silu_and_mul)
+    invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel)
 
 
 @dataclass
@@ -20,25 +18,36 @@ class BatchedMMConfig:
     K: int
     N: int
 
+
 @dataclass
 class BatchedMMTensors:
     A: torch.Tensor  # [E, max_tokens, K]
     B: torch.Tensor  # [E, K, N] - column major
     C: torch.Tensor  # [E, max_tokens, N]
-    num_expert_tokens: torch.Tensor # [E]
+    num_expert_tokens: torch.Tensor  # [E]
 
     @staticmethod
     def make_tensors(config: BatchedMMConfig):
-        A = torch.randn((config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", dtype=config.dtype) / 50.0
-        B = torch.randn((config.num_experts, config.N, config.K), device="cuda", dtype=config.dtype) / 50.0
-        C = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", dtype=config.dtype)
-        num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32)
-        return BatchedMMTensors(A,B,C, num_expert_tokens)
-
-
-def ref_impl(A: torch.Tensor,
-             B: torch.Tensor,
-             C: torch.Tensor,
+        A = torch.randn(
+            (config.num_experts, config.max_tokens_per_expert, config.K),
+            device="cuda",
+            dtype=config.dtype) / 50.0
+        B = torch.randn((config.num_experts, config.N, config.K),
+                        device="cuda",
+                        dtype=config.dtype) / 50.0
+        C = torch.zeros(
+            (config.num_experts, config.max_tokens_per_expert, config.N),
+            device="cuda",
+            dtype=config.dtype)
+        num_expert_tokens = torch.randint(low=0,
+                                          high=config.max_tokens_per_expert,
+                                          size=(config.num_experts, ),
+                                          device="cuda",
+                                          dtype=torch.int32)
+        return BatchedMMTensors(A, B, C, num_expert_tokens)
+
+
+def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:
 
     num_expert_tokens_cpu = num_expert_tokens.clone()
@@ -49,49 +58,50 @@ def ref_impl(A: torch.Tensor,
         num_tokens = num_expert_tokens_cpu[e]
         C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
 
-
     return C
 
+
 @pytest.mark.parametrize("num_experts", [16, 32])
 @pytest.mark.parametrize("max_tokens_per_expert", [512])
 @pytest.mark.parametrize("K", [256])
 @pytest.mark.parametrize("N", [512])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_batched_mm(num_experts: int,
-                    max_tokens_per_expert: int,
-                    K: int,
-                    N: int,
-                    dtype: torch.dtype):
+def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
+                    N: int, dtype: torch.dtype):
 
     config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
     tensors = BatchedMMTensors.make_tensors(config)
 
     test_output = tensors.C
     ref_output = test_output.clone()
 
-
-    compute_tl_dtype = {torch.float16 : tl.float16,
-                        torch.bfloat16 : tl.bfloat16,
-                        torch.float32 : tl.float32}[test_output.dtype]
-    invoke_moe_batched_triton_kernel(tensors.A,
-                                     tensors.B,
-                                     test_output,
-                                     tensors.num_expert_tokens,
-                                     compute_tl_dtype,
-                                     # Quantization data
-                                     None,
-                                     None,
-                                     None,
-                                     # Quantization schemes
-                                     False,
-                                     False,
-                                     False,
-                                     config = {"BLOCK_SIZE_M": 16,
-                                               "BLOCK_SIZE_N": 16,
-                                               "BLOCK_SIZE_K": 16})
-
-
-    ref_output = ref_impl(tensors.A, tensors.B, ref_output, tensors.num_expert_tokens)
+    compute_tl_dtype = {
+        torch.float16: tl.float16,
+        torch.bfloat16: tl.bfloat16,
+        torch.float32: tl.float32
+    }[test_output.dtype]
+    invoke_moe_batched_triton_kernel(
+        tensors.A,
+        tensors.B,
+        test_output,
+        tensors.num_expert_tokens,
+        compute_tl_dtype,
+        # Quantization data
+        None,
+        None,
+        None,
+        # Quantization schemes
+        False,
+        False,
+        False,
+        config={
+            "BLOCK_SIZE_M": 16,
+            "BLOCK_SIZE_N": 16,
+            "BLOCK_SIZE_K": 16
+        })
+
+    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
+                          tensors.num_expert_tokens)
     #torch.cuda.synchronize()
     #print (f"ref output {ref_output}")
     #print (f"test output {test_output}")
@@ -106,6 +116,7 @@ class BatchedSiluMulConfig:
     max_tokens_per_expert: int
     D: int
 
+
 @dataclass
 class BatchedSiluMulTensors:
     input: torch.Tensor
@@ -114,16 +125,24 @@ class BatchedSiluMulTensors:
 
     @staticmethod
    def make_tensors(config: BatchedSiluMulConfig):
-        input = torch.randn((config.num_experts, config.max_tokens_per_expert, config.D * 2), device="cuda", dtype=config.dtype) / 50.0
-        output = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.D), device="cuda", dtype=config.dtype)
-        num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32)
+        input = torch.randn(
+            (config.num_experts, config.max_tokens_per_expert, config.D * 2),
+            device="cuda",
+            dtype=config.dtype) / 50.0
+        output = torch.zeros(
+            (config.num_experts, config.max_tokens_per_expert, config.D),
+            device="cuda",
+            dtype=config.dtype)
+        num_expert_tokens = torch.randint(low=0,
+                                          high=config.max_tokens_per_expert,
+                                          size=(config.num_experts, ),
+                                          device="cuda",
+                                          dtype=torch.int32)
         return BatchedSiluMulTensors(input, output, num_expert_tokens)
 
 
-def ref_batched_silu_mul(
-        output: torch.Tensor,
-        input: torch.Tensor,
-        num_expert_tokens: torch.Tensor) -> torch.Tensor:
+def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor,
+                         num_expert_tokens: torch.Tensor) -> torch.Tensor:
 
     num_expert_tokens_cpu = num_expert_tokens.clone()
     num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
@@ -140,10 +159,8 @@ def ref_batched_silu_mul(
 @pytest.mark.parametrize("max_tokens_per_expert", [128])
 @pytest.mark.parametrize("D", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_batched_silu_mul(num_experts: int,
-                          max_tokens_per_expert: int,
-                          D: int,
-                          dtype: torch.dtype):
+def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int, D: int,
                          dtype: torch.dtype):
 
     config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D)
     tensors = BatchedSiluMulTensors.make_tensors(config)
@@ -153,6 +170,7 @@ def test_batched_silu_mul(num_experts: int,
 
     ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens)
 
-    invoke_batched_silu_and_mul(test_out, tensors.input, tensors.expert_num_tokens)
+    invoke_batched_silu_and_mul(test_out, tensors.input,
+                                tensors.expert_num_tokens)
 
     torch.testing.assert_close(test_out, ref_out)
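Note on the test above: ref_impl is a per-expert masked matmul, so the Triton kernel is only checked on the first num_expert_tokens[e] rows of each expert's slice. Below is a minimal standalone sketch of that reference behaviour; it is not part of the commit, uses CPU tensors and small arbitrary shapes purely for illustration, whereas the test itself uses CUDA tensors with the parametrized shapes.

import torch

# Illustrative sizes only; the test parametrizes num_experts, max_tokens, K, N.
E, max_tokens, K, N = 4, 8, 16, 32

A = torch.randn(E, max_tokens, K)   # per-expert activations [E, max_tokens, K]
B = torch.randn(E, N, K)            # per-expert weights [E, N, K]
C = torch.zeros(E, max_tokens, N)   # output buffer [E, max_tokens, N]
num_expert_tokens = torch.randint(0, max_tokens, (E, ), dtype=torch.int32)

for e in range(E):
    # Only the first num_expert_tokens[e] rows of expert e hold real tokens;
    # the remaining rows of C[e] stay zero, matching ref_impl above.
    n = int(num_expert_tokens[e])
    C[e, :n, :] = A[e, :n, :] @ B[e].transpose(0, 1)

The silu-and-mul test in the same file follows the same pattern: invoke_batched_silu_and_mul is compared against a Python reference and checked with torch.testing.assert_close.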

tests/kernels/moe/test_moe.py

Lines changed: 1 addition & 3 deletions
@@ -15,8 +15,7 @@
     torch_moe_single)
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_topk, moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
     fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -26,7 +25,6 @@
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-from vllm.model_executor.layers.activation import SiluAndMul
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]

tests/kernels/quantization/test_block_fp8.py

Lines changed: 2 additions & 3 deletions
@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
-    _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8)
+    _valid_deep_gemm_shape, deep_gemm_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size)
@@ -437,8 +437,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
 
     topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
 
-    out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights,
-                            topk_ids)
+    out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
 
     #print(f"{out.sum()=}")
     #print(f"{ref_out.sum()=}")
