@@ -101,25 +101,14 @@ def __init__(self, cpu_group):
101
101
logger .debug ("PPLX NVSHMEM UID = %s" , uid )
102
102
nvshmem_init (uid , self .rank , self .world_size )
103
103
104
- # self.handle_cache = Cache()
105
- self .handle_caches = [Cache (), Cache ()]
104
+ self .handle_cache = Cache ()
106
105
107
106
def get_handle (self , kwargs ):
108
107
import pplx_kernels as pplx
109
- return self .handle_caches [ 0 ] .get_or_create (
108
+ return self .handle_cache .get_or_create (
110
109
kwargs , pplx .AllToAll .internode
111
110
if self .internode else pplx .AllToAll .intranode )
112
111
113
- def get_handles (self , kwargs ):
114
- import pplx_kernels as pplx
115
- first_handle = self .handle_caches [0 ].get_or_create (
116
- kwargs , pplx .AllToAll .internode
117
- if self .internode else pplx .AllToAll .intranode )
118
- second_handle = self .handle_caches [1 ].get_or_create (
119
- kwargs , pplx .AllToAll .internode
120
- if self .internode else pplx .AllToAll .intranode )
121
- return [first_handle , second_handle ]
122
-
123
112
def dispatch (self , hidden_states : torch .Tensor ,
124
113
router_logits : torch .Tensor ):
125
114
raise NotImplementedError
@@ -128,10 +117,9 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
128
117
raise NotImplementedError
129
118
130
119
def destroy (self ):
131
- for handle_cache in self .handle_caches :
132
- with handle_cache ._lock :
133
- for _ , handle in handle_cache ._cache .items ():
134
- handle .destroy ()
120
+ with self .handle_cache ._lock :
121
+ for _ , handle in self .handle_cache ._cache .items ():
122
+ handle .destroy ()
135
123
136
124
if self .internode :
137
125
from pplx_kernels .nvshmem import nvshmem_finalize
@@ -148,7 +136,7 @@ def __init__(self, cpu_group):
148
136
assert has_deep_ep (
149
137
), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa
150
138
super ().__init__ (cpu_group )
151
- self .handle_caches = [ Cache (), Cache ()]
139
+ self .handle_cache = Cache ()
152
140
153
141
# This is the DeepEP default. Stick to it till we can establish
154
142
# reasonable defaults based on profiling.
@@ -175,7 +163,6 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
175
163
176
164
def __init__ (self , cpu_group ):
177
165
super ().__init__ (cpu_group )
178
- self .handle_cache = self .handle_caches [0 ]
179
166
180
167
def _make_all2all_kwargs (self ) -> dict [Any , Any ]:
181
168
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -224,7 +211,6 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
224
211
225
212
def __init__ (self , cpu_group ):
226
213
super ().__init__ (cpu_group )
227
- self .handle_cache = self .handle_caches [0 ]
228
214
229
215
def _make_all2all_kwargs (
230
216
self ,
@@ -271,8 +257,3 @@ def get_handle(self, kwargs):
271
257
handle : deep_ep .Buffer = self .handle_cache .get_or_create (
272
258
buffer_kwargs , deep_ep .Buffer )
273
259
return handle
274
-
275
- def get_handles (self , kwargs ):
276
- handle = self .get_handle (kwargs )
277
- # For DeepEP we use the same handle for microbatching
278
- return [handle , handle ]
0 commit comments