
Commit a674762

Varun's fixes/cleanups
Signed-off-by: Bill Nell <[email protected]>
1 parent 9f8e241 commit a674762

11 files changed: +55 −205 lines

tests/kernels/moe/test_batched_moe.py

Lines changed: 1 addition & 71 deletions
@@ -7,7 +7,7 @@
 import triton.language as tl
 
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel)
+    invoke_moe_batched_triton_kernel)
 
 
 @dataclass
@@ -103,75 +103,5 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 
     ref_output = ref_impl(tensors.A, tensors.B, ref_output,
                           tensors.num_expert_tokens)
-    #torch.cuda.synchronize()
-    #print (f"ref output {ref_output}")
-    #print (f"test output {test_output}")
 
     torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3)
-
-
-@dataclass
-class BatchedSiluMulConfig:
-    dtype: torch.dtype
-    num_experts: int
-    max_tokens_per_expert: int
-    D: int
-
-
-@dataclass
-class BatchedSiluMulTensors:
-    input: torch.Tensor
-    output: torch.Tensor
-    expert_num_tokens: torch.Tensor
-
-    @staticmethod
-    def make_tensors(config: BatchedSiluMulConfig):
-        input = torch.randn(
-            (config.num_experts, config.max_tokens_per_expert, config.D * 2),
-            device="cuda",
-            dtype=config.dtype) / 50.0
-        output = torch.zeros(
-            (config.num_experts, config.max_tokens_per_expert, config.D),
-            device="cuda",
-            dtype=config.dtype)
-        num_expert_tokens = torch.randint(low=0,
-                                          high=config.max_tokens_per_expert,
-                                          size=(config.num_experts, ),
-                                          device="cuda",
-                                          dtype=torch.int32)
-        return BatchedSiluMulTensors(input, output, num_expert_tokens)
-
-
-def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor,
-                         num_expert_tokens: torch.Tensor) -> torch.Tensor:
-
-    num_expert_tokens_cpu = num_expert_tokens.clone()
-    num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
-    num_experts = num_expert_tokens.size(0)
-
-    for e in range(num_experts):
-        num_tokens = num_expert_tokens_cpu[e].item()
-        out_part = output[e, :num_tokens, :]
-        in_part = input[e, :num_tokens, :]
-        torch.ops._C.silu_and_mul(out_part, in_part)
-
-
-@pytest.mark.parametrize("num_experts", [16, 32])
-@pytest.mark.parametrize("max_tokens_per_expert", [128])
-@pytest.mark.parametrize("D", [128, 256])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int, D: int,
-                          dtype: torch.dtype):
-
-    config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D)
-    tensors = BatchedSiluMulTensors.make_tensors(config)
-
-    test_out = tensors.output
-    ref_out = torch.zeros_like(test_out)
-
-    ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens)
-
-    invoke_batched_silu_and_mul(test_out, tensors.input,
-                                tensors.expert_num_tokens)
-
-    torch.testing.assert_close(test_out, ref_out)
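
For reference, the SiLU-and-mul semantics that the removed test exercised via torch.ops._C.silu_and_mul can be written in plain PyTorch. The sketch below is illustrative only and not part of the commit; the helper name is made up.

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x has shape [..., 2 * D]: the first half is gated by SiLU and
    # multiplied elementwise by the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 16)      # e.g. 4 tokens, D = 8
out = silu_and_mul_ref(x)   # shape [4, 8]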

vllm/distributed/parallel_state.py

Lines changed: 3 additions & 0 deletions
@@ -968,6 +968,9 @@ def pplx_finalize():
     if PPLX_DID_INIT:
         from pplx_kernels.nvshmem import nvshmem_finalize
         logger.info("PPLX finalize")
+        from vllm.model_executor.layers.fused_moe.layer import (
+            _all_to_all_cache)
+        _all_to_all_cache.destroy()
         nvshmem_finalize()
 
 
vllm/distributed/utils.py

Lines changed: 5 additions & 5 deletions
@@ -23,7 +23,7 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import get_tcp_uri
+from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
 
 logger = init_logger(__name__)
 
@@ -362,11 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
     Destroy ProcessGroup returned by
     stateless_init_torch_distributed_process_group().
     """
-    # TODO: pytorch < 2.7?
-    if False:
+    if is_torch_equal_or_newer("2.7"):
+        pg.shutdown()
+    else:
         # Lazy import for non-CUDA backends.
         from torch.distributed.distributed_c10d import _shutdown_backend
         _shutdown_backend(pg)
-    else:
-        pg.shutdown()
+
     _unregister_process_group(pg.group_name)
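
The gate above uses pg.shutdown() on PyTorch 2.7 or newer and falls back to the private _shutdown_backend helper otherwise. As a rough standalone illustration of that kind of version check (this is not the vllm.utils implementation of is_torch_equal_or_newer, just a sketch with the same intent):

import torch
from packaging import version

def torch_at_least(min_version: str) -> bool:
    # Compare release tuples so builds like "2.7.0.dev20250101+cu128"
    # still count as >= "2.7".
    return version.parse(torch.__version__).release >= version.parse(min_version).release

if torch_at_least("2.7"):
    print("use pg.shutdown()")
else:
    print("fall back to _shutdown_backend(pg)")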

vllm/forward_context.py

Lines changed: 4 additions & 12 deletions
@@ -27,10 +27,8 @@
 
 @dataclass
 class DPMetadata:
-    max_tokens_across_dp: torch.Tensor
-    num_tokens_across_dp: torch.Tensor
+    max_tokens_across_dp_cpu: torch.Tensor
     cu_tokens_across_dp_cpu: torch.Tensor
-    dp_rank_num_tokens: torch.Tensor
 
 
 @dataclass
@@ -93,16 +91,10 @@ def set_forward_context(attn_metadata: Any,
                                          dtype=torch.int32)
         from vllm.distributed.parallel_state import get_dp_group
         dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
-        #TODO device? (tms)
-        max_tokens_across_dp = torch.max(
-            num_tokens_tensor)  #.to(device="cuda")
+        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
-        dp_rank_num_tokens = torch.tensor(
-            [num_tokens],
-            dtype=torch.uint32,
-            device=vllm_config.device_config.device)
-        dp_metadata = DPMetadata(max_tokens_across_dp, num_tokens_tensor,
-                                 cu_tokens_across_dp_cpu, dp_rank_num_tokens)
+        dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
+                                 cu_tokens_across_dp_cpu)
 
     global _forward_context
     prev_context = _forward_context
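
After this change DPMetadata carries only two CPU tensors. A simplified illustration of how they are derived from the per-DP-rank token counts (no process group needed here; in the real code the CPU all_reduce above fills num_tokens_tensor):

import torch

# Pretend the all_reduce already produced one token count per DP rank.
num_tokens_tensor = torch.tensor([5, 3, 8, 2], dtype=torch.int32)

max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)           # tensor(8)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)  # values [5, 8, 16, 18]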

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 4 additions & 95 deletions
@@ -12,65 +12,6 @@
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 
 
-@triton.jit
-def batched_silu_and_mul_kernel(
-        output,  # [E, MAX_NUM_TOKENS, D]
-        input,  # [E, MAX_NUM_TOKENS, D * 2]
-        expert_num_tokens,  # [E]
-        stride_oe,
-        stride_om,
-        stride_ie,
-        stride_im,
-        compute_type: tl.constexpr,
-        D,
-        BLOCK_M: tl.constexpr,
-        BLOCK_D: tl.constexpr):
-
-    expert_id = tl.program_id(axis=0)
-    e_num_tokens = tl.load(expert_num_tokens + expert_id)
-    if e_num_tokens == 0:
-        # early exit
-        return
-
-    pid_m = tl.program_id(axis=1)
-    cta_m_start = pid_m * BLOCK_M
-    if cta_m_start >= e_num_tokens:
-        # early exit
-        return
-
-    cta_input_ptr = input + expert_id * stride_ie + cta_m_start * stride_im
-    cta_output_ptr = output + expert_id * stride_oe + cta_m_start * stride_om
-
-    cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
-    offs_m = tl.arange(0, BLOCK_M)[:, None]
-    mask_m = offs_m < cta_m_size
-
-    cta_input_ptrs = cta_input_ptr + offs_m * stride_im
-    cta_output_ptrs = cta_output_ptr + offs_m * stride_om
-
-    # offset by D
-    offs_D = tl.arange(0, BLOCK_D)
-    cta_input_ptrs = cta_input_ptrs + offs_D
-    cta_output_ptrs = cta_output_ptrs + offs_D
-
-    for d in range(0, tl.cdiv(D, BLOCK_D)):
-        mask_D = offs_D < (D - (d * BLOCK_D))
-        mask_tile = mask_m & mask_D
-
-        x_tile = tl.load(cta_input_ptrs, mask=mask_tile,
-                         other=0.0).to(dtype=tl.float32)
-        y_tile = tl.load(cta_input_ptrs + D, mask=mask_tile, other=0.0)
-
-        # silu and mul
-        out_tile = (x_tile * (1.0 /
-                              (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type)
-        out_tile = out_tile * y_tile
-        tl.store(cta_output_ptrs, out_tile, mask=mask_tile)
-
-        cta_input_ptrs = cta_input_ptrs + BLOCK_D
-        cta_output_ptrs = cta_output_ptrs + BLOCK_D
-
-
 @triton.jit
 def moe_mmk(
     a_ptrs,
@@ -438,33 +379,6 @@ def invoke_moe_batched_triton_kernel(
                                     BLOCK_K=BLOCK_K)
 
 
-def invoke_batched_silu_and_mul(
-        output: torch.Tensor,  #[E, MAX_TOKENS, D]
-        input: torch.Tensor,  #[E, MAX_TOKENS, D * 2]
-        expert_num_tokens: torch.Tensor):
-
-    num_experts = output.size(0)
-    max_num_tokens = output.size(1)
-    D = output.size(2)
-
-    BLOCK_D = 1024
-    BLOCK_M = 1
-
-    compute_tl_dtype = {
-        torch.float16: tl.float16,
-        torch.float32: tl.float32,
-        torch.bfloat16: tl.bfloat16
-    }[output.dtype]
-
-    #print(f"compute type {compute_tl_dtype}")
-
-    grid = (num_experts, triton.cdiv(max_num_tokens, BLOCK_M))
-    batched_silu_and_mul_kernel[grid](output, input, expert_num_tokens,
-                                      output.stride(0), output.stride(1),
-                                      input.stride(0), input.stride(1),
-                                      compute_tl_dtype, D, BLOCK_M, BLOCK_D)
-
-
 def rank_chunk(num, r, w):
     rem = num % w
     return (num // w) + (1 if r < rem else 0)
@@ -797,15 +711,10 @@ def apply(
            config=config,
            block_shape=self.block_shape)
 
-        if activation == "silu":
-            invoke_batched_silu_and_mul(output=intermediate_cache2,
-                                        input=intermediate_cache1,
-                                        expert_num_tokens=expert_num_tokens)
-        else:
-            # TODO: would be nice to use expert_num_tokens here to reduce
-            # garbage compute
-            self.activation(activation, intermediate_cache2.view(-1, N // 2),
-                            intermediate_cache1.view(-1, N))
+        # TODO: would be nice to use expert_num_tokens here to reduce
+        # garbage compute
+        self.activation(activation, intermediate_cache2.view(-1, N // 2),
+                        intermediate_cache1.view(-1, N))
 
         #qintermediate_cache2 = intermediate_cache2
        a2q_scale = a2_scale
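
The TODO kept in the last hunk notes that self.activation now runs over all max-token rows of every expert, padding included. A rough sketch of what using expert_num_tokens to skip that padding could look like (illustrative only, mirroring the reference helper removed from the test file, not an implementation in this commit):

import torch
import torch.nn.functional as F

def batched_silu_and_mul_ref(output: torch.Tensor,            # [E, MAX_TOKENS, D]
                             input: torch.Tensor,              # [E, MAX_TOKENS, 2 * D]
                             expert_num_tokens: torch.Tensor) -> None:
    num_tokens_cpu = expert_num_tokens.to("cpu")
    d = output.size(-1)
    for e in range(output.size(0)):
        # Only the first expert_num_tokens[e] rows of each expert carry real tokens.
        n = int(num_tokens_cpu[e].item())
        x = input[e, :n, :]
        output[e, :n, :] = F.silu(x[..., :d]) * x[..., d:]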

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 29 additions & 11 deletions
@@ -70,7 +70,7 @@ class FusedMoEParallelConfig:
 
     @property
     def use_pplx_kernels(self):
-        return self.use_ep and has_pplx
+        return self.dp_size > 1 and self.use_ep and has_pplx
 
     @staticmethod
     def make(tp_size_: int, dp_size_: int,
@@ -277,6 +277,12 @@ def __init__(self):
         self._cache: WeakValueDictionary = WeakValueDictionary()
         self._lock = threading.RLock()  # Reentrant lock for thread safety
 
+    def destroy(self):
+        with self._lock:
+            # TODO: can we do del self._cache?
+            for _, a2a in self._cache.items():
+                a2a.destroy()
+
     def get_or_create(self, **kwargs):
         assert has_pplx
         import pplx_kernels as pplx
@@ -287,7 +293,9 @@ def get_or_create(self, **kwargs):
         with self._lock:
             instance = self._cache.get(key)
             if instance is None:
-                # TODO: should be intranode
+                # TODO (varun): Add support to switch to intranode
+                # when all communications are within the same
+                # node.
                 instance = pplx.AllToAll.internode(**kwargs)
                 self._cache[key] = instance
             return instance
@@ -676,7 +684,7 @@ def _construct_dispatch_combine(
         dp_size = moe.ep_size // moe.dp_size  # dp_size actually means TP.
         rank = moe.ep_rank
 
-        if moe.use_ep and has_pplx:
+        if moe.use_pplx_kernels:
             logger.debug("using pplx dispatch")
 
             all_to_all = get_all_to_all(
@@ -1236,17 +1244,27 @@ def naive_multicast(self, x: torch.Tensor,
 
         return buffer
 
-    def must_reduce_shared_outputs(self) -> bool:
-        return self.dp_size > 1 and self.use_ep and has_pplx
+    def must_reduce_shared_expert_outputs(self) -> bool:
+        """
+        The shared_experts are typically computed using the RowParallelLinear
+        layer. The result of this function is typically used as
+        the reduce_results argument to that module.
+        When just tensor-parallel is used, it is not required to reduce
+        the shared_experts results immediately. Instead we reduce
+        once at the end of the MoE op. (Refer to the DeepSeekV2MoE module.)
+        With EP and the pplx kernels, this is no longer viable, as all
+        GPU ranks in the DP group produce the complete set of hidden_states.
+        Therefore it is required that we reduce the shared_experts output
+        early.
+        """
+        return self.use_pplx_kernels
 
     def maybe_all_reduce_tensor_model_parallel(
             self, final_hidden_states: torch.Tensor):
         """
-        The pplx combine kernel reduce across GPU ranks by default. The pplx
-        kernels are used when EP is enabled. In that case, this function is a
-        no-op.
+        The pplx combine kernel reduces across GPU ranks by default.
         """
-        if self.dp_size > 1 and self.use_ep and has_pplx:
+        if self.use_pplx_kernels:
             return final_hidden_states
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -1291,7 +1309,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                     final_hidden_states)
 
         ctx = get_forward_context()
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp
+        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
 
         num_tokens = full_hidden_states.size(0)
@@ -1313,7 +1331,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
-        if self.dp_size > 1 and self.use_ep and has_pplx:
+        if self.moe_parallel_config.use_pplx_kernels:
             return self.forward_impl_chunked(hidden_states, router_logits)
 
         if self.dp_size > 1:
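
The AllToAllCache changes above add a destroy() hook next to get_or_create(). A standalone sketch of the same pattern, with a dummy handle standing in for pplx.AllToAll (the names here are illustrative, not vLLM or pplx APIs):

import threading
from weakref import WeakValueDictionary

class DummyHandle:
    def destroy(self) -> None:
        print("handle destroyed")

class HandleCache:
    def __init__(self):
        # Entries disappear automatically once no caller holds the handle.
        self._cache: WeakValueDictionary = WeakValueDictionary()
        self._lock = threading.RLock()

    def get_or_create(self, **kwargs):
        key = tuple(sorted(kwargs.items()))
        with self._lock:
            instance = self._cache.get(key)
            if instance is None:
                instance = DummyHandle()
                self._cache[key] = instance
            return instance

    def destroy(self):
        with self._lock:
            for _, handle in self._cache.items():
                handle.destroy()

cache = HandleCache()
handle = cache.get_or_create(world_size=2, rank=0)
cache.destroy()  # called at teardown, e.g. from pplx_finalize() above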

vllm/model_executor/models/deepseek_v2.py

Lines changed: 2 additions & 8 deletions
@@ -141,14 +141,8 @@ def __init__(
             intermediate_size=intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            # When just tensor-parallel is used, it isn't required
-            # to reduce the shared_output result. Instead we reduce
-            # at the end of the forward pass.
-            # With EP and the pplx kernels - this is no longer viable
-            # as all GPU ranks in DP, produce the complete set of
-            # hidden_states.
-            # Therefore reduce the shared experts early.
-            reduce_results=self.experts.must_reduce_shared_outputs(),
+            reduce_results=self.experts.must_reduce_shared_expert_outputs(
+            ),
             prefix=f"{prefix}.shared_experts",
         )
 
vllm/model_executor/models/llama4.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def __init__(self,
             quant_config=quant_config,
             bias=False,
             prefix=f"{prefix}.shared_expert",
-            reduce_results=False,  # We need to do scatter before reduce
+            reduce_results=self.experts.must_reduce_shared_expert_outputs(),
         )
 
     def forward(self, hidden_states):
