@@ -779,23 +779,18 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
779779 """
780780 done_req_ids : set [str ] = set ()
781781 for req_id , handles in list (transfers .items ()):
782- running_reqs = []
783- for handle in handles :
782+ for handle , xfer_stime in handles :
784783 xfer_state = self .nixl_wrapper .check_xfer_state (handle )
785784 if xfer_state == "DONE" :
786785 # TODO ptarasiewicz: why abort is throwing errors?
787786 # self.nixl_wrapper.release_xfer_handle(handle)
787+ done_req_ids .add (req_id )
788+ del transfers [req_id ]
789+ elif xfer_state == "PROC" :
788790 continue
789- if xfer_state == "PROC" :
790- running_reqs .append (handle )
791791 else :
792792 raise RuntimeError ("Transfer failed with state %s" ,
793793 xfer_state )
794- if len (running_reqs ) == 0 :
795- done_req_ids .add (req_id )
796- del transfers [req_id ]
797- else :
798- transfers [req_id ] = running_reqs
799794 return done_req_ids
800795
801796 def start_load_kv (self , metadata : NixlConnectorMetadata ):
@@ -918,7 +913,9 @@ def _read_blocks(
918913 self .nixl_wrapper .transfer (handle )
919914
920915 # Use handle to check completion in future step().
921- self ._recving_transfers [request_id ].append (handle )
916+ # TODO surface xfer elapsed time
917+ self ._recving_transfers [request_id ].append (
918+ (handle , time .perf_counter ()))
922919
923920 def _get_block_descs_ids (self ,
924921 engine_id : str ,
0 commit comments