
Commit f8f3b68

[P/D] Implement Non-blocking/Async D2H optimization for KV fetching at prefill (#2024)
The NON_BLOCKING mode is functionally ready, while the ASYNC mode still has a piece missing and is currently intended for dev use only.
1 parent 7408e3a commit f8f3b68

2 files changed: +53 additions, -16 deletions

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,7 @@
     VLLM_USE_ASYNC_TRANSFER_IN_PD: bool = False
     VLLM_SKIP_PREFILL_SAMPLING: bool = False
     VLLM_ABORT_REQUEST_KV_CACHE_MISS: bool = True
+    VLLM_FETCH_KV_USE_ASYNC_D2H: int = 0
 
 
 def get_default_cache_root():
@@ -644,6 +645,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: bool(int(os.getenv("VLLM_SKIP_PREFILL_SAMPLING", "0"))),
     "VLLM_ABORT_REQUEST_KV_CACHE_MISS":
     lambda: bool(int(os.getenv("VLLM_ABORT_REQUEST_KV_CACHE_MISS", "1"))),
+
+    # Controls the async device-to-host mode used when fetching KV caches.
+    # 0: Sync copy, 1: Non-blocking copy, 2: Async copy with event
+    "VLLM_FETCH_KV_USE_ASYNC_D2H":
+    lambda: int(os.getenv("VLLM_FETCH_KV_USE_ASYNC_D2H", "0")),
 }
 
 # end-env-vars-definition
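Like the other vLLM environment variables, the new knob is evaluated lazily by the lambda above, so it can be set in the shell or from Python before the worker reads it. A minimal usage sketch (illustrative only, not part of this commit):

# Select the D2H mode before vllm.envs is consulted by the HPU worker.
# 0 = sync copy (default), 1 = non-blocking copy, 2 = async copy with event.
import os

os.environ["VLLM_FETCH_KV_USE_ASYNC_D2H"] = "1"

import vllm.envs as envs

assert envs.VLLM_FETCH_KV_USE_ASYNC_D2H == 1  # parsed as an int by the lambda above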

vllm/worker/hpu_model_runner.py

Lines changed: 47 additions & 16 deletions
@@ -93,6 +93,12 @@
 _SAMPLING_EPS = 1e-5
 
 
+class AsyncD2HMode(IntEnum):
+    SYNC = 0
+    NON_BLOCKING = 1
+    ASYNC = 2
+
+
 class PhaseType(Enum):
     PREFILL = 'prefill'
     PREFIX_PREFILL = 'prefix_prefill'
@@ -802,6 +808,10 @@ def __init__(
         # PD
         self.kv_conf = self.vllm_config.kv_transfer_config
 
+        self.fetch_kv_use_async_d2h = AsyncD2HMode(
+            envs.VLLM_FETCH_KV_USE_ASYNC_D2H)
+        logger.info(f"fetch_kv_use_async_d2h: {self.fetch_kv_use_async_d2h}")
+
     def _set_gc_threshold(self) -> None:
         """
         Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -3112,12 +3122,16 @@ def sync_send_kv_caches(hidden_states):
             )
 
         def fetch_kv_to_host(model, model_input, kv_caches,
-                             hidden_states):
-            input_tokens_tensor_cpu = model_input.input_tokens.to(
-                "cpu"
-            )  # shape: [batch_size, seq_len_padding_to_128]
-            torch.hpu.synchronize(
-            )  # sync here may hurt performance.
+                             hidden_states, async_d2h):
+            if async_d2h in (AsyncD2HMode.NON_BLOCKING,
+                             AsyncD2HMode.ASYNC):
+                input_tokens_tensor_cpu = model_input.input_tokens.to(
+                    "cpu", non_blocking=True
+                )
+            else:
+                input_tokens_tensor_cpu = model_input.input_tokens.to("cpu")
+                torch.hpu.synchronize()
+
             seq_lens = model_input.attn_metadata.seq_lens
             start_layer = model.model.start_layer
             end_layer = model.model.end_layer
@@ -3150,22 +3164,37 @@ def fetch_kv_to_host(model, model_input, kv_caches,
                         key_cache.index_select(
                             0, current_slot_mapping).unsqueeze(0))
                 keys = torch.cat(keys, dim=0)
-                kv_cache_to_sent = keys.cpu()
-                current_hidden_states = hidden_states[
-                    idx].unsqueeze(0).cpu()
-                # ==== graph should end here ======
-                htorch.core.mark_step()
-                torch.hpu.synchronize()
+                if async_d2h in (AsyncD2HMode.NON_BLOCKING,
+                                 AsyncD2HMode.ASYNC):
+                    kv_cache_to_sent = keys.to('cpu', non_blocking=True)
+                    current_hidden_states = hidden_states[
+                        idx].unsqueeze(0).to('cpu', non_blocking=True)
+                else:
+                    kv_cache_to_sent = keys.cpu()
+                    current_hidden_states = hidden_states[
+                        idx].unsqueeze(0).cpu()
+                    htorch.core.mark_step()
+                    torch.hpu.synchronize()
                 kv_caches_send_list.append(kv_cache_to_sent)
                 hidden_states_list.append(current_hidden_states)
                 input_tokens_list.append(current_tokens_cpu)
 
+            if async_d2h == AsyncD2HMode.ASYNC:
+                event = htorch.hpu.Event()
+                event.record()
+            elif async_d2h == AsyncD2HMode.NON_BLOCKING:
+                torch.hpu.synchronize()
+                event = None
+            else:
+                event = None
             return (input_tokens_list, kv_caches_send_list,
-                    hidden_states_list)
+                    hidden_states_list, event)
 
         def send_kv(input_tokens_list, kv_caches_send_list,
-                    hidden_states_list):
+                    hidden_states_list, event):
             cur_time = time.time()
+            if self.fetch_kv_use_async_d2h == AsyncD2HMode.ASYNC and event is not None:
+                event.synchronize()
             get_kv_transfer_group(
             ).send_kv_caches_and_hidden_states_cpu(
                 input_tokens_list, kv_caches_send_list,
@@ -3176,17 +3205,19 @@ def send_kv(input_tokens_list, kv_caches_send_list,
         def async_send_kv_caches(hidden_states):
             (input_tokens_list,
              kv_caches_send_list,
-             hidden_states_list) = \
+             hidden_states_list,
+             event) = \
                 fetch_kv_to_host(self.get_model(),
                                  model_input,
                                  kv_caches,
                                  hidden_states,
-                                 )
+                                 self.fetch_kv_use_async_d2h)
             self.pd_executor_pool.submit(
                 send_kv,
                 input_tokens_list,
                 kv_caches_send_list,
                 hidden_states_list,
+                event
             )
 
         if self.use_async_kv_transfer_in_pd:
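In short, the change splits staging from sending: fetch_kv_to_host queues device-to-host copies (non_blocking=True for modes 1 and 2) and either synchronizes the device once or records an HPU event, and the sender thread waits on that event only in the async mode. A self-contained sketch of that pattern follows; the names stage_to_host and send_from_host are illustrative, not the commit's actual helpers, and it assumes a Gaudi host with habana_frameworks installed.

# Sketch of the non-blocking / event-based D2H pattern used above (names are made up).
import torch
import habana_frameworks.torch as htorch

SYNC, NON_BLOCKING, ASYNC = 0, 1, 2

def stage_to_host(device_tensors, mode):
    # Queue the D2H copies; only the sync mode blocks on each individual copy.
    host_tensors = [t.to("cpu", non_blocking=(mode != SYNC)) for t in device_tensors]
    if mode == ASYNC:
        event = htorch.hpu.Event()    # mark the point after the queued copies
        event.record()
        return host_tensors, event    # consumer waits on the event later
    if mode == NON_BLOCKING:
        torch.hpu.synchronize()       # wait once here for all queued copies
    return host_tensors, None

def send_from_host(host_tensors, event):
    # Runs on the sender thread (e.g. a thread-pool worker).
    if event is not None:
        event.synchronize()           # copies are guaranteed complete past this point
    # ... hand host_tensors to the KV transfer channel ...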
