
Commit 44124af

simplify a2a kernel dispatching
Signed-off-by: Sage Moore <[email protected]>
1 parent b6d162f commit 44124af

File tree: 4 files changed, +14 / -36 lines

vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/base_device_communicator.py
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
vllm/model_executor/layers/fused_moe/layer.py

vllm/distributed/device_communicators/all2all.py

Lines changed: 6 additions & 25 deletions
```diff
@@ -101,25 +101,14 @@ def __init__(self, cpu_group):
         logger.debug("PPLX NVSHMEM UID = %s", uid)
         nvshmem_init(uid, self.rank, self.world_size)
 
-        # self.handle_cache = Cache()
-        self.handle_caches = [Cache(), Cache()]
+        self.handle_cache = Cache()
 
     def get_handle(self, kwargs):
         import pplx_kernels as pplx
-        return self.handle_caches[0].get_or_create(
+        return self.handle_cache.get_or_create(
             kwargs, pplx.AllToAll.internode
             if self.internode else pplx.AllToAll.intranode)
 
-    def get_handles(self, kwargs):
-        import pplx_kernels as pplx
-        first_handle = self.handle_caches[0].get_or_create(
-            kwargs, pplx.AllToAll.internode
-            if self.internode else pplx.AllToAll.intranode)
-        second_handle = self.handle_caches[1].get_or_create(
-            kwargs, pplx.AllToAll.internode
-            if self.internode else pplx.AllToAll.intranode)
-        return [first_handle, second_handle]
-
     def dispatch(self, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor):
         raise NotImplementedError
@@ -128,10 +117,9 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
 
     def destroy(self):
-        for handle_cache in self.handle_caches:
-            with handle_cache._lock:
-                for _, handle in handle_cache._cache.items():
-                    handle.destroy()
+        with self.handle_cache._lock:
+            for _, handle in self.handle_cache._cache.items():
+                handle.destroy()
 
         if self.internode:
             from pplx_kernels.nvshmem import nvshmem_finalize
@@ -148,7 +136,7 @@ def __init__(self, cpu_group):
         assert has_deep_ep(
         ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels."  # noqa
         super().__init__(cpu_group)
-        self.handle_caches = [Cache(), Cache()]
+        self.handle_cache = Cache()
 
         # This is the DeepEP default. Stick to it till we can establish
         # reasonable defaults based on profiling.
@@ -175,7 +163,6 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
 
     def __init__(self, cpu_group):
         super().__init__(cpu_group)
-        self.handle_cache = self.handle_caches[0]
 
     def _make_all2all_kwargs(self) -> dict[Any, Any]:
         # Defaults for internode and intranode are taken from DeepEP tests.
@@ -224,7 +211,6 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
 
     def __init__(self, cpu_group):
         super().__init__(cpu_group)
-        self.handle_cache = self.handle_caches[0]
 
     def _make_all2all_kwargs(
         self,
@@ -271,8 +257,3 @@ def get_handle(self, kwargs):
         handle: deep_ep.Buffer = self.handle_cache.get_or_create(
             buffer_kwargs, deep_ep.Buffer)
         return handle
-
-    def get_handles(self, kwargs):
-        handle = self.get_handle(kwargs)
-        # For DeepEP we use the same handle for microbatching
-        return [handle, handle]
```
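
For context on the pattern this file converges on: each manager now keeps a single `Cache` keyed by the all-to-all config, and `get_handle()` / `destroy()` operate on that one cache. The sketch below illustrates the idea with toy stand-ins; the `Cache` internals, `FakeHandle`, and `ToyAll2AllManager` are simplified assumptions for illustration, not the real vLLM or pplx classes.

```python
import threading
from typing import Any, Callable


class Cache:
    """Simplified stand-in for the handle cache used by the managers above."""

    def __init__(self):
        self._lock = threading.Lock()
        self._cache: dict[Any, Any] = {}

    def get_or_create(self, kwargs: dict[str, Any], func: Callable[..., Any]):
        # One handle per config: build it on first use, then reuse it.
        key = tuple(sorted(kwargs.items()))
        with self._lock:
            if key not in self._cache:
                self._cache[key] = func(**kwargs)
            return self._cache[key]


class FakeHandle:
    """Stand-in for a pplx AllToAll kernel handle."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.destroyed = False

    def destroy(self):
        self.destroyed = True


class ToyAll2AllManager:
    """Post-commit shape: one cache, one get_handle(), one destroy()."""

    def __init__(self):
        self.handle_cache = Cache()

    def get_handle(self, kwargs: dict[str, Any]) -> FakeHandle:
        return self.handle_cache.get_or_create(kwargs, FakeHandle)

    def destroy(self):
        # Tear down every cached handle, as in the diff above.
        with self.handle_cache._lock:
            for _, handle in self.handle_cache._cache.items():
                handle.destroy()


manager = ToyAll2AllManager()
h1 = manager.get_handle({"max_num_tokens": 256, "num_experts": 8})
h2 = manager.get_handle({"max_num_tokens": 256, "num_experts": 8})
assert h1 is h2  # same config -> same cached handle
manager.destroy()
assert h1.destroyed
```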

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -60,9 +60,6 @@ def get_handle(self, kwargs):
         # and reuse it for the same config.
         raise NotImplementedError
 
-    def get_handles(self, kwargs):
-        raise NotImplementedError
-
     def dispatch(self, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor):
         raise NotImplementedError
```
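
With `get_handles()` gone, the base manager contract is just the single-handle accessor plus the dispatch/combine hooks. A rough sketch of that surface is below; the class name and any method not visible in the diffs are assumptions, not the exact vLLM definitions.

```python
import torch


class ToyAll2AllManagerBase:
    """Sketch of the remaining interface after this change: no get_handles()."""

    def get_handle(self, kwargs):
        # Concrete managers build (or fetch a cached) kernel handle for this
        # config and reuse it for the same config.
        raise NotImplementedError

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        # Optional cleanup hook; concrete managers override it (assumption).
        pass
```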

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -48,13 +48,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168]
 
     def __init__(self,
-                 buffers: list[deep_ep.Buffer],
+                 buffer: deep_ep.Buffer,
                  max_tokens_per_rank: int,
                  num_dispatchers: int,
                  use_fp8_dispatch: bool = False):
         super().__init__()
 
-        self.buffers = buffers
+        self.buffer = buffer
         self.max_tokens_per_rank = max_tokens_per_rank
         self.use_fp8_dispatch = use_fp8_dispatch
         # The dispatch function returns a handle that the combine function
@@ -154,7 +154,7 @@ def prepare(
         # Dispatch
         dbo_maybe_run_recv_hook()
         expert_x, expert_num_tokens, handle, _, recv_hook = \
-            self.buffers[a2a_idx].low_latency_dispatch(a1,
+            self.buffer.low_latency_dispatch(a1,
                                                  topk_ids,
                                                  self.max_tokens_per_rank,
                                                  num_experts,
@@ -200,7 +200,7 @@ def finalize(
 
         # TODO (varun) : Enable zero copy mode
         dbo_maybe_run_recv_hook()
-        _, _, recv_hook = self.buffers[a2a_idx].low_latency_combine(fused_expert_output,
+        _, _, recv_hook = self.buffer.low_latency_combine(fused_expert_output,
                                                           topk_ids,
                                                           combine_topk_weights,
                                                           handle,
```
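
The prepare/finalize side mirrors the manager change: the class stores one `deep_ep.Buffer` and calls `low_latency_dispatch` / `low_latency_combine` on it directly instead of indexing `self.buffers[a2a_idx]`. A minimal sketch of that shape follows, using a fake buffer and heavily trimmed signatures; the argument lists and return handling are illustrative assumptions, not the real DeepEP API.

```python
def _noop():
    """Placeholder recv hook."""


class FakeBuffer:
    """Stand-in for deep_ep.Buffer, exposing the two calls seen in the diff."""

    def low_latency_dispatch(self, a1, topk_ids, max_tokens_per_rank,
                             num_experts):
        # Pretend dispatch: echo the input back with a dummy handle and hook.
        return a1, len(topk_ids), object(), None, _noop

    def low_latency_combine(self, fused_expert_output, topk_ids,
                            combine_topk_weights, handle):
        return None, None, _noop


class ToyDeepEPLLPrepareAndFinalize:
    """Post-commit shape: a single buffer, no per-microbatch indexing."""

    def __init__(self, buffer, max_tokens_per_rank: int,
                 num_dispatchers: int, use_fp8_dispatch: bool = False):
        self.buffer = buffer
        self.max_tokens_per_rank = max_tokens_per_rank
        self.num_dispatchers = num_dispatchers
        self.use_fp8_dispatch = use_fp8_dispatch

    def prepare(self, a1, topk_ids, num_experts):
        # Dispatch now goes straight through self.buffer.
        expert_x, expert_num_tokens, handle, _, recv_hook = \
            self.buffer.low_latency_dispatch(a1, topk_ids,
                                             self.max_tokens_per_rank,
                                             num_experts)
        return expert_x, expert_num_tokens, handle, recv_hook

    def finalize(self, fused_expert_output, topk_ids, combine_topk_weights,
                 handle):
        # Combine uses the handle returned by the earlier dispatch.
        _, _, recv_hook = self.buffer.low_latency_combine(
            fused_expert_output, topk_ids, combine_topk_weights, handle)
        return recv_hook


pf = ToyDeepEPLLPrepareAndFinalize(FakeBuffer(), max_tokens_per_rank=128,
                                   num_dispatchers=2)
expert_x, _, handle, _ = pf.prepare(a1=[1.0, 2.0], topk_ids=[0, 1],
                                    num_experts=8)
pf.finalize(expert_x, [0, 1], [1.0, 1.0], handle)
```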

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -142,10 +142,10 @@ def _maybe_make_prepare_finalize(
             all_to_all_args[
                 "group_name"] = all2all_manager.cpu_group.group_name
 
-            handles = all2all_manager.get_handles(all_to_all_args)
+            handle = all2all_manager.get_handle(all_to_all_args)
 
             prepare_finalize = PplxPrepareAndFinalize(
-                handles,
+                handle,
                 max_num_tokens=moe.max_num_tokens,
                 num_local_experts=moe.num_local_experts,
                 num_dispatchers=num_dispatchers,
@@ -171,7 +171,7 @@ def _maybe_make_prepare_finalize(
                 num_global_experts=moe.num_experts,
                 num_local_experts=moe.num_experts //
                 all2all_manager.world_size)
-            handles = all2all_manager.get_handles(all_to_all_args)
+            handle = all2all_manager.get_handle(all_to_all_args)
 
             # Note : We may want to use FP8 dispatch even otherwise just to
             # reduce datamovement
@@ -182,7 +182,7 @@ def _maybe_make_prepare_finalize(
                 == DEEPEP_QUANT_BLOCK_SHAPE)
 
             prepare_finalize = DeepEPLLPrepareAndFinalize(
-                handles,
+                handle,
                 max_tokens_per_rank=moe.max_num_tokens,
                 num_dispatchers=all2all_manager.world_size,
                 use_fp8_dispatch=use_fp8_dispatch,
```
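
On the caller side, both backends now follow the same flow: ask the manager for one handle and pass it as the first argument to the prepare/finalize constructor. A hedged sketch of that flow is below; the `Stub*` classes are placeholders, not vLLM types.

```python
class StubManager:
    """Stand-in for an all2all manager that only exposes get_handle()."""

    def get_handle(self, all_to_all_args: dict):
        # A real manager returns a cached kernel handle for this config.
        return ("handle", tuple(sorted(all_to_all_args.items())))


class StubPrepareAndFinalize:
    """Stand-in for PplxPrepareAndFinalize / DeepEPLLPrepareAndFinalize:
    both now take a single handle (or buffer) as their first argument."""

    def __init__(self, handle, **kwargs):
        self.handle = handle
        self.kwargs = kwargs


def make_prepare_finalize(manager: StubManager, all_to_all_args: dict,
                          max_num_tokens: int) -> StubPrepareAndFinalize:
    # One handle in, one prepare/finalize object out, for either backend.
    handle = manager.get_handle(all_to_all_args)
    return StubPrepareAndFinalize(handle, max_num_tokens=max_num_tokens)


pf = make_prepare_finalize(StubManager(),
                           {"max_num_tokens": 256, "num_experts": 8},
                           max_num_tokens=256)
assert pf.handle[0] == "handle"
```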
