Commit d2862c0

lint

Signed-off-by: Bill Nell <[email protected]>

1 parent 3e2cf4b commit d2862c0

17 files changed: +123 -162 lines changed


csrc/activation_kernels.cu

Lines changed: 3 additions & 1 deletion
@@ -70,7 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   int64_t num_tokens = input.numel() / input.size(-1);              \
   dim3 grid(num_tokens);                                            \
   dim3 block(std::min(d, 1024));                                    \
-  if (num_tokens == 0) { return; }                                  \
+  if (num_tokens == 0) {                                            \
+    return;                                                         \
+  }                                                                 \
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();     \
   VLLM_DISPATCH_FLOATING_TYPES(                                     \

csrc/dispatch_utils.h

Lines changed: 9 additions & 8 deletions
@@ -66,17 +66,18 @@
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

 #define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
-  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
-  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
-  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
-  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
-  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)       \
-  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)     \
-  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)

 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

 #define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
+  AT_DISPATCH_SWITCH( \
+      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))

examples/offline_inference/data_parallel.py

Lines changed: 8 additions & 15 deletions
@@ -31,7 +31,6 @@
 from time import sleep

 from vllm import LLM, SamplingParams
-from vllm.config import CompilationConfig
 from vllm.utils import get_open_port


@@ -116,20 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
                                      max_tokens=[16, 20][global_dp_rank % 2])

     # Create an LLM.
-    cconfig = CompilationConfig(
-        level=3,
-        #cudagraph_capture_sizes=[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208],
-        #cudagraph_capture_sizes=[512,256,1],
-        #cudagraph_capture_sizes=[192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1]
-        #cudagraph_capture_sizes=[128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1]
+    llm = LLM(
+        model=model,
+        tensor_parallel_size=GPUs_per_dp_rank,
+        enforce_eager=enforce_eager,
+        enable_expert_parallel=True,
+        trust_remote_code=trust_remote_code,
     )
-    llm = LLM(model=model,
-              tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=enforce_eager,
-              enable_expert_parallel=True,
-              compilation_config=cconfig,
-              trust_remote_code=trust_remote_code,
-              )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for i, output in enumerate(outputs):
@@ -172,7 +164,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         proc = Process(target=main,
                        args=(args.model, dp_size, local_dp_rank,
                              global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size, args.enforce_eager, args.trust_remote_code))
+                             tp_size, args.enforce_eager,
+                             args.trust_remote_code))
         proc.start()
         procs.append(proc)
     exit_code = 0
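Note: with the CompilationConfig plumbing removed, this example now runs with vLLM's default compilation settings. If a specific compilation level is still wanted for local experiments, the dropped pattern can be reinstated roughly as sketched below (illustrative only; `model`, `GPUs_per_dp_rank`, `enforce_eager` and `trust_remote_code` are the variables already in scope inside this example's main(), and the exact CompilationConfig fields accepted may vary by vLLM version):

    from vllm import LLM
    from vllm.config import CompilationConfig

    # Hypothetical re-addition of the compilation config this commit removes.
    cconfig = CompilationConfig(level=3)
    llm = LLM(model=model,
              tensor_parallel_size=GPUs_per_dp_rank,
              enforce_eager=enforce_eager,
              enable_expert_parallel=True,
              compilation_config=cconfig,
              trust_remote_code=trust_remote_code)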

tests/kernels/moe/test_batched_moe.py

Lines changed: 2 additions & 1 deletion
@@ -62,7 +62,8 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,


 @pytest.mark.parametrize("num_experts", [16, 32])
-@pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512])
+@pytest.mark.parametrize("max_tokens_per_expert",
+                         [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
 @pytest.mark.parametrize("N", [128, 256, 512, 1024])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])

tests/kernels/moe/test_moe.py

Lines changed: 1 addition & 3 deletions
@@ -11,8 +11,7 @@
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

 import vllm.model_executor.layers.fused_moe  # noqa
-from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
-                                 torch_moe_single)
+from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -27,7 +26,6 @@
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
-from vllm.model_executor.layers.activation import SiluAndMul

 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1, 4]

tests/kernels/moe/test_pplx_moe.py

Lines changed: 13 additions & 13 deletions
@@ -28,8 +28,7 @@
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedDispatchCombine,
-    BatchedExperts)
+    BatchedDispatchCombine, BatchedExperts)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
@@ -246,15 +245,9 @@ def batched_moe(a, w1, w2, topk_weight, topk_ids):

     fused_experts = FusedMoEModularKernel(
         BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0),
-        BatchedExperts(a.shape[0])
-    )
+        BatchedExperts(a.shape[0]))

-    return fused_experts(a,
-                         w1,
-                         w2,
-                         topk_weight,
-                         topk_ids,
-                         num_experts)
+    return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)


 # TODO: same as torch_moe but with fused_topk factored out.
@@ -301,9 +294,15 @@ def test_fused_moe_batched_experts(
     torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
     batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)

-    torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(baseline_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)
     torch.set_printoptions(profile="full")
-    torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(baseline_output,
+                               batched_output,
+                               atol=2e-2,
+                               rtol=0)


 def rank_chunk(num, r, w):
@@ -585,7 +584,8 @@ def _pplx_moe(
     topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
     torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
     pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
-    batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
+    batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
+                                  topk_ids)

     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
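For context, the simplified batched_moe helper above composes a FusedMoEModularKernel from BatchedDispatchCombine and BatchedExperts. The sketch below shows that composition standalone; the shapes (w1: [E, 2N, K], w2: [E, K, N]) and the call signature mirror what is visible in this diff, but treat it as an illustration under those assumptions (it also assumes a CUDA device and skips the vLLM config context the real tests set up):

    import torch

    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
        BatchedDispatchCombine, BatchedExperts)
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
    from vllm.model_executor.layers.fused_moe.modular_kernel import (
        FusedMoEModularKernel)

    M, K, N, E, topk = 32, 128, 256, 8, 2
    a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") / 10
    w1 = torch.randn(E, 2 * N, K, dtype=torch.bfloat16, device="cuda") / 10
    w2 = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda") / 10
    score = torch.randn(M, E, dtype=torch.bfloat16, device="cuda")

    topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

    # Single-rank composition, mirroring batched_moe() in this test file.
    fused_experts = FusedMoEModularKernel(
        BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0),
        BatchedExperts(a.shape[0]))

    out = fused_experts(a, w1, w2, topk_weight, topk_ids, E)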

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 15 additions & 12 deletions
@@ -10,7 +10,6 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     get_config_dtype_str, try_get_optimal_moe_config)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
-from vllm.utils import direct_register_custom_op


 @triton.jit
@@ -473,7 +472,8 @@ def rank_chunk(num, r, w):

 class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine):

-    def __init__(self, max_num_tokens: Optional[int], world_size: int, dp_size: int, rank: int):
+    def __init__(self, max_num_tokens: Optional[int], world_size: int,
+                 dp_size: int, rank: int):
         super().__init__()
         self.world_size = world_size
         self.dp_size = dp_size
@@ -510,16 +510,18 @@ def dispatch(
                                                minlength=num_experts)
             self.max_num_tokens = int(tokens_per_expert.max().item())
         else:
-            tokens_per_expert = torch.zeros(num_experts, dtype=torch.int,
+            tokens_per_expert = torch.zeros(num_experts,
+                                            dtype=torch.int,
                                             device=a1.device)

         rem_experts = num_experts % self.world_size
         num_local_experts = ((num_experts // self.world_size) +
                              (1 if self.rank < rem_experts else 0))

-        b_a1 = torch.zeros((num_local_experts, self.max_num_tokens, hidden_dim),
-                           dtype=a1.dtype,
-                           device=a1.device)
+        b_a1 = torch.zeros(
+            (num_local_experts, self.max_num_tokens, hidden_dim),
+            dtype=a1.dtype,
+            device=a1.device)

         first_expert = (((num_experts // self.world_size) * self.rank) +
                         rem_experts - self.rank)
@@ -540,7 +542,8 @@ def dispatch(
         for expert_id in range(first_expert, last_expert):
             topks = torch.any(topk_ids == expert_id, dim=1).flatten()
             rows = torch.count_nonzero(topks.flatten())
-            b_a1[expert_id - first_expert, :rows, :] = a1[:topks.numel()][topks]
+            b_a1[expert_id -
+                 first_expert, :rows, :] = a1[:topks.numel()][topks]
             tokens_per_expert[expert_id - first_expert] = rows

         return b_a1, a1_scale, tokens_per_expert
@@ -561,7 +564,7 @@ def combine(

         output.fill_(0)

-        first_expert = num_local_experts * self.rank # NOT QUITE RIGHT
+        first_expert = num_local_experts * self.rank  # NOT QUITE RIGHT
         last_expert = first_expert + num_local_experts

         # for expert_id in range(first_expert, last_expert):
@@ -658,8 +661,9 @@ def apply(
             num_experts = global_num_experts
         out = _resize_cache(workspace13,
                             (num_experts, max_num_tokens * num_dp, hidden_dim))
-        num_local_experts = w1.shape[0] #expert_num_tokens.numel()
-        assert num_local_experts == w1.shape[0], f"{num_local_experts} == {w1.shape[0]}"
+        num_local_experts = w1.shape[0]  #expert_num_tokens.numel()
+        assert num_local_experts == w1.shape[
+            0], f"{num_local_experts} == {w1.shape[0]}"

         N = w1.shape[1] // 2

@@ -821,8 +825,7 @@ def apply(
         # invoke_batched_silu_and_mul(output=intermediate_cache2,
         #                             input=intermediate_cache1,
         #                             expert_num_tokens=expert_num_tokens)
-        self.activation(activation,
-                        intermediate_cache2.view(-1, N//2),
+        self.activation(activation, intermediate_cache2.view(-1, N // 2),
                         intermediate_cache1.view(-1, N))

         #qintermediate_cache2 = intermediate_cache2
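The dispatch loop reflowed above gathers, for each local expert, the rows of a1 whose topk_ids include that expert into a fixed-size per-expert buffer. A self-contained sketch of that gather pattern follows; the names mirror the diff, but the snippet is illustrative only, not the vLLM implementation:

    import torch

    num_tokens, hidden_dim, num_experts, topk = 8, 16, 4, 2
    max_num_tokens = num_tokens  # worst case: every token routes to one expert

    a1 = torch.randn(num_tokens, hidden_dim)
    topk_ids = torch.randint(0, num_experts, (num_tokens, topk))

    b_a1 = torch.zeros((num_experts, max_num_tokens, hidden_dim), dtype=a1.dtype)
    tokens_per_expert = torch.zeros(num_experts, dtype=torch.int)

    for expert_id in range(num_experts):
        # Mask of tokens routed to this expert by any of their top-k choices.
        topks = torch.any(topk_ids == expert_id, dim=1)
        rows = torch.count_nonzero(topks)
        # Pack the selected rows into this expert's slot in the batched buffer.
        b_a1[expert_id, :rows, :] = a1[topks]
        tokens_per_expert[expert_id] = rows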

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 11 deletions
@@ -21,7 +21,7 @@
     _resize_cache, moe_kernel_quantize_input)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op, round_up
+from vllm.utils import direct_register_custom_op

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

@@ -885,8 +885,7 @@ def fused_topk(
         M,
         topk,
         dtype=torch.int32 if indices_type is None else indices_type,
-        device=hidden_states.device
-    )
+        device=hidden_states.device)
     token_expert_indices = torch.empty(M,
                                        topk,
                                        dtype=torch.int32,
@@ -980,7 +979,7 @@ def get_config_dtype_str(
     return None


-# TODO: use scalar_type?
+# TODO: use scalar_type instead of bools?
 def get_config_qtype(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
@@ -1239,8 +1238,8 @@ def fused_experts_impl(
         assert hidden_states.shape[1] // 2 == w1.shape[
             2], "Hidden size mismatch"
     else:
-        assert hidden_states.shape[1] == w1.shape[2], \
-            f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}"
+        assert hidden_states.shape[1] == w1.shape[2], (
+            f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")

     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -1655,16 +1654,11 @@ def apply(
             expert_ids = torch.repeat_interleave(expert_ids,
                                                  max_num_tokens,
                                                  dim=0)
-            print(f"EXPERT_IDS {expert_ids}")
-            #num_tokens_post_padded = torch.tensor([num_tokens],
-            #                                      device=hidden_states.device,
-            #                                      dtype=torch.int32)
             num_tokens_post_padded = torch.zeros(1,
                                                  device=hidden_states.device,
                                                  dtype=torch.int32)
             num_tokens_post_padded.fill_(max_num_tokens)
             hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-            #print(f"P = {sorted_token_ids}, {hidden_states.shape}")

             invoke_fused_moe_kernel(hidden_states,
                                     w1,
