Skip to content

Commit 97d553d

Browse files
authored
<PD> refine send on rank0 only logic (#2043)
Prefill DPxTP can be enabled upon this change
1 parent f8f3b68 commit 97d553d

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -241,20 +241,22 @@ def send_kv_caches_and_hidden_states_cpu(
241241
hidden_states_list: List[torch.Tensor],
242242
) -> None:
243243
start_time = time.time()
244-
if self.rank != 0:
244+
if self.rank % self.tp_size != 0:
245245
# only the first rank will send kv cache
246246
return
247247
assert len(input_tokens_list) == len(kv_caches_send_list)
248248
assert len(input_tokens_list) == len(hidden_states_list)
249249
for idx, input_tokens in enumerate(input_tokens_list):
250250
store_key_prefix = self.tensor_hash(input_tokens)
251-
store_kvcache_key = f"{store_key_prefix}_{self.rank}"
252-
store_hidden_key = f"{store_key_prefix}_hidden_{self.rank}"
251+
store_kvcache_key = f"{store_key_prefix}_{self.rank % self.tp_size}"
252+
store_hidden_key = f"{store_key_prefix}_hidden_{self.rank % self.tp_size}"
253253

254254
self.kv_store.put_tensor(store_kvcache_key,
255255
kv_caches_send_list[idx])
256256
self.kv_store.put_tensor(store_hidden_key, hidden_states_list[idx])
257-
logger.info("[rank %d]: KV send DONE. send %d, takes %f s", self.rank,
257+
logger.info("[rank %d][tp size %d]:KV send DONE. send %d, takes %f s",
258+
self.rank,
259+
self.tp_size,
258260
len(input_tokens_list),
259261
time.time() - start_time)
260262

@@ -266,7 +268,7 @@ def send_kv_caches_and_hidden_states_hpu(
266268
hidden_or_intermediate_states: Union[torch.Tensor,
267269
IntermediateTensors],
268270
) -> None:
269-
if self.rank != 0:
271+
if self.rank % self.tp_size != 0:
270272
# only the first rank will send kv cache
271273
return
272274
start_time = time.time()
@@ -312,12 +314,12 @@ def send_kv_caches_and_hidden_states_hpu(
312314
keys = torch.cat(keys, dim=0)
313315
kvcache_to_sent = keys.cpu()
314316
logger.debug("kv cache reshape time: %s", time.time() - start_time)
315-
store_kvcache_key = f"{store_key_prefix}_{self.rank}"
317+
store_kvcache_key = f"{store_key_prefix}_{self.rank % self.tp_size}"
316318
self.kv_store.put_tensor(store_kvcache_key, kvcache_to_sent)
317319

318320
logger.debug("put kv cache key: %s", store_kvcache_key)
319321

320-
hidden_key = f"{store_key_prefix}_hidden_{self.rank}"
322+
hidden_key = f"{store_key_prefix}_hidden_{self.rank % self.tp_size}"
321323
self.kv_store.put_tensor(
322324
hidden_key,
323325
hidden_or_intermediate_states[idx].unsqueeze(0).cpu())

0 commit comments

Comments
 (0)