9393_SAMPLING_EPS = 1e-5
9494
9595
96+ class AsyncD2HMode (IntEnum ):
97+ SYNC = 0
98+ NON_BLOCKING = 1
99+ ASYNC = 2
100+
101+
96102class PhaseType (Enum ):
97103 PREFILL = 'prefill'
98104 PREFIX_PREFILL = 'prefix_prefill'
@@ -802,6 +808,10 @@ def __init__(
802808 # PD
803809 self .kv_conf = self .vllm_config .kv_transfer_config
804810
811+ self .fetch_kv_use_async_d2h = AsyncD2HMode (
812+ envs .VLLM_FETCH_KV_USE_ASYNC_D2H )
813+ logger .info (f"fetch_kv_use_async_d2h: { self .fetch_kv_use_async_d2h } " )
814+
805815 def _set_gc_threshold (self ) -> None :
806816 """
807817 Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -3112,12 +3122,16 @@ def sync_send_kv_caches(hidden_states):
31123122 )
31133123
31143124 def fetch_kv_to_host (model , model_input , kv_caches ,
3115- hidden_states ):
3116- input_tokens_tensor_cpu = model_input .input_tokens .to (
3117- "cpu"
3118- ) # shape: [batch_size, seq_len_padding_to_128]
3119- torch .hpu .synchronize (
3120- ) # sync here may hurt performance.
3125+ hidden_states , async_d2h ):
3126+ if async_d2h in (AsyncD2HMode .NON_BLOCKING ,
3127+ AsyncD2HMode .ASYNC ):
3128+ input_tokens_tensor_cpu = model_input .input_tokens .to (
3129+ "cpu" , non_blocking = True
3130+ )
3131+ else :
3132+ input_tokens_tensor_cpu = model_input .input_tokens .to ("cpu" )
3133+ torch .hpu .synchronize ()
3134+
31213135 seq_lens = model_input .attn_metadata .seq_lens
31223136 start_layer = model .model .start_layer
31233137 end_layer = model .model .end_layer
@@ -3150,22 +3164,37 @@ def fetch_kv_to_host(model, model_input, kv_caches,
31503164 key_cache .index_select (
31513165 0 , current_slot_mapping ).unsqueeze (0 ))
31523166 keys = torch .cat (keys , dim = 0 )
3153- kv_cache_to_sent = keys .cpu ()
3154- current_hidden_states = hidden_states [
3155- idx ].unsqueeze (0 ).cpu ()
3156- # ==== graph should end here ======
3157- htorch .core .mark_step ()
3158- torch .hpu .synchronize ()
3167+ if async_d2h in (AsyncD2HMode .NON_BLOCKING ,
3168+ AsyncD2HMode .ASYNC ):
3169+ kv_cache_to_sent = keys .to ('cpu' , non_blocking = True )
3170+ current_hidden_states = hidden_states [
3171+ idx ].unsqueeze (0 ).to ('cpu' , non_blocking = True )
3172+ else :
3173+ kv_cache_to_sent = keys .cpu ()
3174+ current_hidden_states = hidden_states [
3175+ idx ].unsqueeze (0 ).cpu ()
3176+ htorch .core .mark_step ()
3177+ torch .hpu .synchronize ()
31593178 kv_caches_send_list .append (kv_cache_to_sent )
31603179 hidden_states_list .append (current_hidden_states )
31613180 input_tokens_list .append (current_tokens_cpu )
31623181
3182+ if async_d2h == AsyncD2HMode .ASYNC :
3183+ event = htorch .hpu .Event ()
3184+ event .record ()
3185+ elif async_d2h == AsyncD2HMode .NON_BLOCKING :
3186+ torch .hpu .synchronize ()
3187+ event = None
3188+ else :
3189+ event = None
31633190 return (input_tokens_list , kv_caches_send_list ,
3164- hidden_states_list )
3191+ hidden_states_list , event )
31653192
31663193 def send_kv (input_tokens_list , kv_caches_send_list ,
3167- hidden_states_list ):
3194+ hidden_states_list , event ):
31683195 cur_time = time .time ()
3196+ if self .fetch_kv_use_async_d2h == AsyncD2HMode .ASYNC and event is not None :
3197+ event .synchronize ()
31693198 get_kv_transfer_group (
31703199 ).send_kv_caches_and_hidden_states_cpu (
31713200 input_tokens_list , kv_caches_send_list ,
@@ -3176,17 +3205,19 @@ def send_kv(input_tokens_list, kv_caches_send_list,
def async_send_kv_caches(hidden_states):
    """Stage this step's KV caches and hidden states on the host, then
    hand the actual send off to the PD executor pool.

    The D2H copy mode comes from ``self.fetch_kv_use_async_d2h``.
    ``fetch_kv_to_host`` returns host-side token/KV/hidden-state lists
    plus an event: a recorded HPU event in ASYNC mode (the copies may
    still be in flight), or ``None`` otherwise. The event is forwarded
    to ``send_kv`` so the worker can wait for the copies to complete
    before transmitting.

    NOTE(review): closure over enclosing scope — relies on ``self``,
    ``model_input``, ``kv_caches``, ``fetch_kv_to_host`` and ``send_kv``
    being defined in the surrounding method.
    """
    (input_tokens_list,
     kv_caches_send_list,
     hidden_states_list,
     event) = fetch_kv_to_host(self.get_model(),
                               model_input,
                               kv_caches,
                               hidden_states,
                               self.fetch_kv_use_async_d2h)
    # Offload the (potentially slow) transfer so the main execution loop
    # is not blocked; `event` travels with the payload.
    self.pd_executor_pool.submit(
        send_kv,
        input_tokens_list,
        kv_caches_send_list,
        hidden_states_list,
        event,
    )
31913222
31923223 if self .use_async_kv_transfer_in_pd :
0 commit comments