@@ -773,8 +773,8 @@ def _get_new_notifs(self) -> set[str]:
773773 req_id , tp_ratio = notif .decode ("utf-8" ).rsplit (":" , 1 )
774774 self .consumer_notification_counts_by_req [req_id ] += 1
775775 # Wait all consumers (D) to be done reading before freeing.
776- if self .consumer_notification_counts_by_req [
777- req_id ] == tp_ratio :
776+ if self .consumer_notification_counts_by_req [req_id ] == int (
777+ tp_ratio ) :
778778 notified_req_ids .add (req_id )
779779 del self .consumer_notification_counts_by_req [req_id ]
780780 return notified_req_ids
@@ -904,18 +904,20 @@ def _read_blocks(
904904 def _get_block_descs_ids (self , engine_id : str ,
905905 block_ids : list [int ]) -> list [int ]:
906906 """Get the descs ids for a set of block ids."""
907+ # TODO docs
907908
908909 # range(1) for MLA, range(2) otherwise.
909910 region_ids = range (self .num_regions )
911+ # TODO using a diff num of blocks here in dst and src
910912 num_blocks = self .dst_num_blocks [engine_id ]
911913
912914 # Compute the desc ids for each block.
913915 descs_ids : list [int ] = []
914916 for reg_id in region_ids :
915917 for block_id in block_ids :
916- for kv_block in range (self .block_size ):
918+ for slot_id in range (self .block_size ):
917919 descs_ids .append (reg_id * num_blocks * self .block_size +
918- block_id * self .block_size + kv_block )
920+ block_id * self .block_size + slot_id )
919921 return descs_ids
920922
921923
0 commit comments