@@ -799,7 +799,7 @@ def get_finished(self, scheduler_output) -> tuple[set[str], set[str]]:
799799 to Rank 0 once their transaction is done + Rank 0 returns
800800 finished sets to Scheduler only once all ranks are done.
801801 """
802- logger .debug (f'buke: get_finished: |{ scheduler_output = } |{ vars (scheduler_output .kv_connector_metadata )= } ' )
802+ logger .debug (f'buke: get_finished: { self . engine_id = } |{ scheduler_output = } |{ vars (scheduler_output .kv_connector_metadata )= } ' )
803803 #print(f"buke: get_finished: {self.kv_caches_hpu['model.layers.0.self_attn.attn'][0].data_ptr()=}|{self.kv_caches_hpu['model.layers.0.self_attn.attn'][1].data_ptr()=}|")
804804 k00 ,v00 = self .kv_caches_hpu ['model.layers.0.self_attn.attn' ]
805805 #print(f'buke: get_finished hpu: {k00.shape=}|{k00.sum(dim=[1,2])[100:400]=}', flush=True)
@@ -810,9 +810,10 @@ def get_finished(self, scheduler_output) -> tuple[set[str], set[str]]:
810810 done_sending = self ._get_new_notifs ()
811811 done_recving = self ._pop_done_transfers (self ._recving_transfers )
812812 requests = scheduler_output .kv_connector_metadata .requests
813- logger .debug (f'buke get_finished: { self ._transfering_req_meta = } ' )
814-
815- #logger.debug(f'buke: get_finished: {done_sending=}|{done_recving=}|{requests=}')
813+ logger .debug (f'buke get_finished: { self ._transfering_req_meta = } ' )
814+ logger .debug (f'buke: get_finished: { done_sending = } |{ done_recving = } |{ requests = } ' )
815+ if len (requests ) == 0 :
816+ done_sending = set ()
816817 if len (done_sending ) > 0 or len (done_recving ) > 0 :
817818 k00 ,v00 = self .kv_caches_cpu ['model.layers.0.self_attn.attn' ]
818819 #logger.debug(f'buke: get_finished cpu: {k00.shape=}|{k00.sum(dim=[1,2])[100:400]=}')
@@ -834,7 +835,6 @@ def get_finished(self, scheduler_output) -> tuple[set[str], set[str]]:
834835 self .kv_caches_hpu [layer ][1 ][start :end ].copy_ (v [start :end ], non_blocking = False )
835836 k00 ,v00 = self .kv_caches_hpu ['model.layers.0.self_attn.attn' ]
836837 del self ._transfering_req_meta [req ]
837-
838838 logger .debug (f'buke: get_finished hpu: { k00 .shape = } |{ k00 .sum (dim = [1 ,2 ])[100 :400 ]= } ' )
839839 logger .debug (
840840 "Rank %s, get_finished: %s requests done sending "
@@ -895,6 +895,7 @@ def _get_new_notifs(self) -> set[str]:
895895 notified_req_ids : set [str ] = set ()
896896 for notifs in self .nixl_wrapper .get_new_notifs ().values ():
897897 for notif in notifs :
898+ logger .debug (f'buke _get_new_notifs: { notif .decode ("utf-8" )= } ' )
898899 req_id , tp_ratio = notif .decode ("utf-8" ).rsplit (":" , 1 )
899900 self .consumer_notification_counts_by_req [req_id ] += 1
900901 # Wait all consumers (D) to be done reading before freeing.
@@ -919,6 +920,8 @@ def _pop_done_transfers(
919920 for handle , xfer_stime in handles :
920921 xfer_state = self .nixl_wrapper .check_xfer_state (handle )
921922 if xfer_state == "DONE" :
923+ xfer_end_time = time .perf_counter ()
924+ logger .debug (f"buke _pop_done_transfers: { req_id = } |{ handle = } |{ xfer_end_time = } |{ xfer_end_time - xfer_stime = } " )
922925 self .nixl_wrapper .release_xfer_handle (handle )
923926 done_req_ids .add (req_id )
924927 del transfers [req_id ]
@@ -934,7 +937,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
934937 Start loading by triggering non-blocking nixl_xfer.
935938 We check for these trnxs to complete in each step().
936939 """
937- #print(f'buke start_load_kv: {len(metadata.requests)=}')
940+ #logger.debug(f'buke start_load_kv: {self._get_new_notifs()=}')
941+ logger .debug (f'buke start_load_kv: { len (metadata .requests )= } |{ self .engine_id = } ' )
938942 for req_id , meta in metadata .requests .items ():
939943 logger .debug (
940944 "start_load_kv for request %s from remote engine %s. "
@@ -997,6 +1001,7 @@ def _read_blocks(
9971001 remote_rank = self .tp_rank // tp_ratio
9981002 agent_name = self ._remote_agents [dst_engine_id ][remote_rank ]
9991003 self .nixl_wrapper .send_notif (agent_name , notif_msg = notif_id )
1004+ #logger.debug(f'buke send_notif: {agent_name=}|{notif_id=}')
10001005 return
10011006
10021007 # Partial prefix cache hit: just read uncomputed blocks.
@@ -1058,10 +1063,11 @@ def _read_blocks(
10581063 remote_block_descs_ids ,
10591064 notif_msg = notif_id ,
10601065 )
1061- logger .debug (f'buke: >>>>> real transfer start >>>>> { remote_block_descs_ids = } |{ local_block_descs_ids = } ' )
1066+ logger .debug (f'buke: >>>>> real transfer start >>>>> { remote_block_descs_ids = } |{ local_block_descs_ids = } | { notif_id = } | { time . perf_counter () = } ' )
10621067 # Begin async xfer.
1068+ time_before_async_trans = time .perf_counter ()
10631069 self .nixl_wrapper .transfer (handle )
1064-
1070+ logger . debug ( f'buke: transfer { time . perf_counter () - time_before_async_trans = } ' )
10651071 # Use handle to check completion in future step().
10661072 # TODO (NickLucche) surface xfer elapsed time
10671073 self ._recving_transfers [request_id ].append (
0 commit comments