
Commit 6c0e085

cosmetic changes
Signed-off-by: Bill Nell <[email protected]>
1 parent 9f0ea4f commit 6c0e085


9 files changed: +88 -89 lines changed

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 5 additions & 5 deletions
@@ -50,8 +50,8 @@ def _valid_deep_gemm(hidden_states: torch.Tensor,
         logger.debug("DeepGemm disabled: expert map NYI.")
         return False
 
-    M = hidden_states.shape[0]
-    _, K, N = w2.shape
+    M = hidden_states.size(0)
+    _, K, N = w2.size()
     if not _valid_deep_gemm_shape(M, N, K):
         logger.debug("DeepGemm disabled: unalinged problem size.")
         return False
@@ -113,10 +113,10 @@ def apply(
         import deep_gemm as dg
 
         a1q = hidden_states
-        _, N, K = w1.shape
+        _, N, K = w1.size()
 
         assert global_num_experts != -1
-        assert w2.shape[1] == K
+        assert w2.size(1) == K
 
         a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute(
             a1q,
@@ -128,7 +128,7 @@ def apply(
         )
 
         # Note: M_sum is different than the pre-permuted shape of a1q.
-        M_sum = a1q.shape[0]
+        M_sum = a1q.size(0)
         workspace1 = _resize_cache(workspace13, (M_sum, N))
         workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
         workspace3 = _resize_cache(workspace13, (M_sum, K))
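
Reviewer note: the edits in this file only swap Tensor.shape indexing for the equivalent Tensor.size() calls. In PyTorch, t.size() returns a torch.Size (a tuple subclass) and t.size(i) equals t.shape[i], so behavior is unchanged. A quick sanity check with made-up dimensions, not part of the commit:

    import torch

    w2 = torch.randn(4, 16, 32)  # stand-in for a (num_experts, K, N) weight

    # size() unpacks like a tuple, and size(i) matches shape[i]
    _, K, N = w2.size()
    assert (K, N) == (w2.shape[1], w2.shape[2])
    assert w2.size(0) == w2.shape[0] == 4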

vllm/model_executor/layers/fused_moe/dispatch_combine.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def dispatch(
         apply_router_weight_on_input: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         if apply_router_weight_on_input:
-            topk = topk_ids.shape[1]
+            topk = topk_ids.size(1)
             # TODO: this only works for topK=1, will need to update for topK>1
             assert topk == 1, \
                 "apply_router_weight_on_input is only implemented for topk=1"

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 34 additions & 34 deletions
@@ -395,7 +395,7 @@ def invoke_moe_batched_triton_kernel(
     assert max_num_tokens % BLOCK_M == 0
 
     grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
-            triton.cdiv(B.shape[1], BLOCK_N))
+            triton.cdiv(B.size(1), BLOCK_N))
 
     batched_triton_kernel[grid](
         A,
@@ -493,17 +493,17 @@ def dispatch(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         assert a1.dim() == 2
         assert topk_ids.dim() == 2
-        assert topk_ids.shape[0] == a1.shape[0]
+        assert topk_ids.size(0) == a1.size(0)
 
         if apply_router_weight_on_input:
-            topk = topk_ids.shape[1]
+            topk = topk_ids.size(1)
             # TODO: this only works for topK=1, will need to update for topK>1
             assert topk == 1, \
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))
 
-        num_tokens, hidden_dim = a1.shape
-        topk = topk_ids.shape[1]
+        num_tokens, hidden_dim = a1.size()
+        topk = topk_ids.size(1)
 
         if self.max_num_tokens is None:
             tokens_per_expert = torch.bincount(topk_ids.view(-1),
@@ -543,10 +543,10 @@ def combine(
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
     ) -> None:
-        num_tokens = topk_ids.shape[0]
-        num_local_experts = fused_expert_output.shape[0]
-        K = fused_expert_output.shape[-1]
-        assert output.shape[0] == num_tokens and output.shape[1] == K
+        num_tokens = topk_ids.size(0)
+        num_local_experts = fused_expert_output.size(0)
+        K = fused_expert_output.size(-1)
+        assert output.size(0) == num_tokens and output.size(1) == K
 
         output.fill_(0)
 
@@ -559,7 +559,7 @@ def combine(
                 rows = torch.count_nonzero(topks)
                 rhs = fused_expert_output[expert_id - first_expert, :rows, :]
                 if not apply_router_weight_on_input:
-                    rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1))
+                    rhs.mul_(topk_weights[topkws].view(rhs.size(0), 1))
                 output[topks] = output[topks] + rhs
 
 
@@ -599,8 +599,8 @@ def workspace_shapes(
     ) -> Tuple[int, int, torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
-        max_num_tokens = a.shape[
-            0] if self.max_num_tokens is None else self.max_num_tokens
+        max_num_tokens = a.size(
+            0) if self.max_num_tokens is None else self.max_num_tokens
         #print(f"WORKSPACE {max_num_tokens} {num_dp}")
         workspace13 = num_experts * max_num_tokens * num_dp * K
         workspace2 = max_num_tokens * num_dp * N
@@ -627,27 +627,27 @@ def apply(
     ) -> torch.Tensor:
         assert hidden_states.dim() == 3
         assert expert_num_tokens is not None
-        hidden_dim = hidden_states.shape[-1]
+        hidden_dim = hidden_states.size(-1)
 
         if self.max_num_tokens is None:
-            max_num_tokens = hidden_states.shape[1]
+            max_num_tokens = hidden_states.size(1)
         else:
             max_num_tokens = self.max_num_tokens
 
         num_dp = self.world_size // self.dp_size
         num_experts = global_num_experts
         out = _resize_cache(workspace13,
                             (num_experts, max_num_tokens * num_dp, hidden_dim))
-        num_local_experts = w1.shape[0]
-        assert num_local_experts == w1.shape[
-            0], f"{num_local_experts} == {w1.shape[0]}"
+        num_local_experts = w1.size(0)
+        assert num_local_experts == w1.size(0), (
+            f"{num_local_experts} == {w1.size(0)}")
 
-        N = w1.shape[1] // 2
+        N = w1.size(1) // 2
 
         # Not cudagraph friendly
         assert (torch.cuda.is_current_stream_capturing()
-                or torch.all(expert_num_tokens <= max_num_tokens)), (
-                    f"{expert_num_tokens} <= {max_num_tokens}")
+                or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
+                    f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
 
         for expert in range(num_local_experts):
             # Indexing expert_num_tokens doesn't work w/cudagraphs
@@ -699,8 +699,8 @@ def workspace_shapes(
     ) -> Tuple[int, int, torch.dtype]:
         assert a.dim() == 2
         num_dp = self.world_size // self.dp_size
-        max_num_tokens = a.shape[
-            0] if self.max_num_tokens is None else self.max_num_tokens
+        max_num_tokens = a.size(
+            0) if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
         workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
         return (workspace13, workspace2, a.dtype)
@@ -726,12 +726,12 @@ def apply(
     ) -> torch.Tensor:
         # Check constraints.
         if self.use_int4_w4a16:
-            assert hidden_states.shape[-1] // 2 == w1.shape[
-                2], "Hidden size mismatch"
+            assert hidden_states.size(-1) // 2 == w1.size(2), (
+                "Hidden size mismatch")
         else:
-            assert hidden_states.shape[-1] == w1.shape[2], \
-                (f"Hidden size mismatch {hidden_states.shape[-1]} "
-                 f"!= {w1.shape[2]}")
+            assert hidden_states.size(-1) == w1.size(2), (
+                f"Hidden size mismatch {hidden_states.size(-1)} "
+                f"!= {w1.size(2)}")
 
         assert hidden_states.is_contiguous(
         ), "Hidden_states must be contiguous"
@@ -745,17 +745,17 @@ def apply(
         E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
             hidden_states, w1, w2, topk_ids)
 
-        assert w1.shape[0] == E
-        assert w2.shape[0] == E
+        assert w1.size(0) == E
+        assert w2.size(0) == E
 
         config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
                                             use_int8_w8a16=self.use_int8_w8a16,
                                             use_int4_w4a16=self.use_int4_w4a16,
                                             dtype=hidden_states.dtype)
 
         config = try_get_optimal_moe_config(
-            w1.shape,
-            w2.shape,
+            w1.size(),
+            w2.size(),
             top_k_num,
             config_dtype,
             num_tokens,
@@ -797,13 +797,13 @@ def apply(
             config=config,
             block_shape=self.block_shape)
 
-        # Fix activations
-        if True:
-            assert activation == "silu"
+        if activation == "silu":
             invoke_batched_silu_and_mul(output=intermediate_cache2,
                                         input=intermediate_cache1,
                                         expert_num_tokens=expert_num_tokens)
         else:
+            # TODO: would be nice to use expert_num_tokens here to reduce
+            # garbage compute
            self.activation(activation, intermediate_cache2.view(-1, N // 2),
                            intermediate_cache1.view(-1, N))
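
Reviewer note: besides the size()/shape swap, the last hunk turns the hard assert activation == "silu" into a real branch, so SiLU keeps the fused batched kernel and any other activation falls back to the generic self.activation(...) path on the flattened caches. For reference, a plain-PyTorch sketch of the silu_and_mul semantics the fused kernel implements (my own reference code, not the kernel):

    import torch
    import torch.nn.functional as F

    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        # The last dim holds [gate, up] halves; output = silu(gate) * up.
        d = x.size(-1) // 2
        return F.silu(x[..., :d]) * x[..., d:]

    intermediate_cache1 = torch.randn(6, 128)                # (tokens, N)
    intermediate_cache2 = silu_and_mul(intermediate_cache1)  # (tokens, N // 2)
    assert intermediate_cache2.size(-1) == intermediate_cache1.size(-1) // 2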

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 8 additions & 8 deletions
@@ -1576,12 +1576,12 @@ def apply(
     ) -> torch.Tensor:
         # Check constraints.
         if self.use_int4_w4a16:
-            assert hidden_states.shape[-1] // 2 == w1.shape[
-                2], "Hidden size mismatch"
+            assert hidden_states.size(-1) // 2 == w1.size(2), (
+                "Hidden size mismatch")
         else:
-            assert hidden_states.shape[-1] == w1.shape[2], \
-                (f"Hidden size mismatch {hidden_states.shape[-1]} "
-                 f"!= {w1.shape[2]}")
+            assert hidden_states.size(-1) == w1.size(2), \
+                (f"Hidden size mismatch {hidden_states.size(-1)} "
+                 f"!= {w1.size(2)}")
 
         assert hidden_states.is_contiguous(
         ), "Hidden_states must be contiguous"
@@ -1637,9 +1637,9 @@ def apply(
                 moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
                                      global_num_experts, expert_map))
         else:
-            max_num_tokens = hidden_states.shape[1]
+            max_num_tokens = hidden_states.size(1)
             sorted_token_ids = torch.arange(0,
-                                            hidden_states.shape[0] *
+                                            hidden_states.size(0) *
                                             max_num_tokens,
                                             device=hidden_states.device,
                                             dtype=torch.int)
@@ -1655,7 +1655,7 @@ def apply(
                                                 device=hidden_states.device,
                                                 dtype=torch.int32)
             num_tokens_post_padded.fill_(max_num_tokens)
-            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+            hidden_states = hidden_states.view(-1, hidden_states.size(-1))
 
         invoke_fused_moe_kernel(hidden_states,
                                 w1,
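
Reviewer note: the reworked asserts encode the int4 packing convention: for use_int4_w4a16, two 4-bit weights share one byte along the hidden dimension, so the packed w1's last dim is half the activation hidden size. A small illustration with made-up sizes (the packed layout here is only schematic):

    import torch

    hidden_size, num_experts, intermediate = 128, 4, 256

    hidden_states = torch.randn(16, hidden_size, dtype=torch.float16)
    # Two int4 values packed per int8 element along the hidden dim.
    w1_packed = torch.zeros(num_experts, 2 * intermediate, hidden_size // 2,
                            dtype=torch.int8)

    assert hidden_states.size(-1) // 2 == w1_packed.size(2), "Hidden size mismatch"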

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 12 additions & 12 deletions
@@ -57,21 +57,21 @@ def _moe_problem_size(
     to be kept in mind.
     """
     assert w1.dim() == 3 and w2.dim() == 3
-    E, N, _ = w1.shape
-    K = w2.shape[1]
+    E, N, _ = w1.size()
+    K = w2.size(1)
 
     if a1.dim() == 2:
         # Make sure we are using the correct a1 (pre-permute).
-        assert topk_ids.shape[0] == a1.shape[0], \
-            f"{topk_ids.shape[0]} != {a1.shape[0]}"
-        M = a1.shape[0]
+        assert topk_ids.size(0) == a1.size(0), \
+            f"{topk_ids.size(0)} != {a1.size(0)}"
+        M = a1.size(0)
     else:
         assert a1.dim() == 3
-        assert a1.shape[0] == E, f"{a1.shape[0]} == {E}"
-        M = a1.shape[1]  # This is max_num_tokens
+        assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+        M = a1.size(1)  # This is max_num_tokens
 
     assert topk_ids.dim() == 2
-    topk = topk_ids.shape[1]
+    topk = topk_ids.size(1)
 
     return E, M, N, K, topk
 
@@ -171,7 +171,7 @@ def workspace_shapes(
 
     def activation(self, activation: str, output: torch.Tensor,
                    input: torch.Tensor) -> None:
-        assert output.shape[-1] * 2 == input.shape[-1]
+        assert output.size(-1) * 2 == input.size(-1)
         if activation == "silu":
             torch.ops._C.silu_and_mul(output, input)
         elif activation == "gelu":
@@ -320,18 +320,18 @@ def forward(
         if global_num_experts == -1:
             global_num_experts = E
 
-        output = a1 if inplace else torch.empty_like(a1)
+        output = a1 if inplace else torch.zeros_like(a1)
 
         workspace13_shape, workspace2_shape, workspace_dtype = (
             self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
                                                 global_num_experts))
 
         # We can reuse the memory between cache1 and cache3 because by the time
         # we need cache3, we're done with cache1
-        workspace13 = torch.empty(workspace13_shape,
+        workspace13 = torch.zeros(workspace13_shape,
                                   device=a1.device,
                                   dtype=workspace_dtype)
-        workspace2 = torch.empty(workspace2_shape,
+        workspace2 = torch.zeros(workspace2_shape,
                                  device=a1.device,
                                  dtype=workspace_dtype)
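
Reviewer note: besides the size()/shape swap, this file switches the output and workspace buffers from torch.empty/empty_like to torch.zeros/zeros_like, trading an extra fill for deterministic contents when an expert writes only a prefix of its slice. A minimal illustration of the difference:

    import torch

    garbage = torch.empty(4, 8)   # uninitialized memory, arbitrary contents
    clean = torch.zeros(4, 8)     # always starts at 0.0

    # Writing only a prefix leaves the tail of `garbage` undefined, while the
    # tail of `clean` is still a well-defined zero.
    garbage[:2] = 1.0
    clean[:2] = 1.0
    assert torch.all(clean[2:] == 0)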

vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py

Lines changed: 6 additions & 6 deletions
@@ -22,9 +22,9 @@ def _moe_permute(
     Determine the sorted_token_ids, expert_ids for the given problem size.
     Permute the hidden states and scales according to `sorted_token_ids`.
     """
-    top_k_num = curr_topk_ids.shape[1]
+    top_k_num = curr_topk_ids.size(1)
 
-    tokens_in_chunk, _ = curr_hidden_states.shape
+    tokens_in_chunk = curr_hidden_states.size(0)
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = (
         moe_align_block_size(curr_topk_ids,
@@ -62,8 +62,8 @@ def _moe_unpermute_and_reduce(
     Unpermute the final result and apply topk_weights, then perform the final
     reduction on the hidden states.
     """
-    M, topk = topk_weight.shape
-    K = curr_hidden.shape[-1]
+    M, topk = topk_weight.size()
+    K = curr_hidden.size(-1)
     if inv_perm is not None:
         curr_hidden = curr_hidden[inv_perm, ...]
     curr_hidden = curr_hidden.view(-1, topk, K)
@@ -110,7 +110,7 @@ def moe_permute(
     - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
     the group which the j-th row of the LHS belong to.`
     """
-    n_token, n_hidden = hidden_states.shape
+    n_token, n_hidden = hidden_states.size()
     assert (n_hidden * hidden_states.element_size()
             ) % 16 == 0, "permue kernel need hidden dim align to 16B"
     permuted_row_size = n_token * topk
@@ -170,7 +170,7 @@ def moe_unpermute(
     - hidden_states (torch.Tensor): The reduced and unpermuted activation
     tensor.
     """
-    n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1]
+    n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1)
     assert (n_hidden * permuted_hidden_states.element_size()
             ) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
     hidden_states = torch.empty((n_token, n_hidden),
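
Reviewer note: the untouched alignment asserts require the hidden dimension, in bytes, to be a multiple of 16; for 2-byte dtypes (fp16/bf16) that means n_hidden must be divisible by 8. The same arithmetic as a standalone check:

    import torch

    hidden_states = torch.randn(10, 4096, dtype=torch.bfloat16)
    n_token, n_hidden = hidden_states.size()

    # Hidden dim must be 16-byte aligned for the permute/unpermute kernels.
    assert (n_hidden * hidden_states.element_size()) % 16 == 0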
