@@ -789,23 +789,18 @@ def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
789789 """
790790 done_req_ids : set [str ] = set ()
791791 for req_id , handles in list (transfers .items ()):
792- running_reqs = []
793- for handle in handles :
792+ for handle , xfer_stime in handles :
794793 xfer_state = self .nixl_wrapper .check_xfer_state (handle )
795794 if xfer_state == "DONE" :
796795 # TODO ptarasiewicz: why abort is throwing errors?
797796 # self.nixl_wrapper.release_xfer_handle(handle)
797+ done_req_ids .add (req_id )
798+ del transfers [req_id ]
799+ elif xfer_state == "PROC" :
798800 continue
799- if xfer_state == "PROC" :
800- running_reqs .append (handle )
801801 else :
802802 raise RuntimeError ("Transfer failed with state %s" ,
803803 xfer_state )
804- if len (running_reqs ) == 0 :
805- done_req_ids .add (req_id )
806- del transfers [req_id ]
807- else :
808- transfers [req_id ] = running_reqs
809804 return done_req_ids
810805
811806 def start_load_kv (self , metadata : NixlConnectorMetadata ):
@@ -899,7 +894,9 @@ def _read_blocks(
899894 self .nixl_wrapper .transfer (handle )
900895
901896 # Use handle to check completion in future step().
902- self ._recving_transfers [request_id ].append (handle )
897+ # TODO surface xfer elapsed time
898+ self ._recving_transfers [request_id ].append (
899+ (handle , time .perf_counter ()))
903900
904901 def _get_block_descs_ids (self , engine_id : str ,
905902 block_ids : list [int ]) -> list [int ]:
0 commit comments