@@ -763,8 +763,8 @@ def _get_new_notifs(self) -> set[str]:
763763 req_id , tp_ratio = notif .decode ("utf-8" ).rsplit (":" , 1 )
764764 self .consumer_notification_counts_by_req [req_id ] += 1
765765 # Wait all consumers (D) to be done reading before freeing.
766- if self .consumer_notification_counts_by_req [
767- req_id ] == tp_ratio :
766+ if self .consumer_notification_counts_by_req [req_id ] == int (
767+ tp_ratio ) :
768768 notified_req_ids .add (req_id )
769769 del self .consumer_notification_counts_by_req [req_id ]
770770 return notified_req_ids
@@ -929,6 +929,7 @@ def _get_block_descs_ids(self,
929929 If layer_idx is provided, we use the region_ids for the given layer.
930930 Otherwise, we use all regions.
931931 """
932+ # TODO TP docs
932933
933934 if layer_idx is None :
934935 region_ids = range (self .num_regions )
@@ -951,9 +952,9 @@ def _get_block_descs_ids(self,
951952 descs_ids : list [int ] = []
952953 for reg_id in region_ids :
953954 for block_id in block_ids :
954- for kv_block in range (self .block_size ):
955+ for slot_id in range (self .block_size ):
955956 descs_ids .append (reg_id * num_blocks * self .block_size +
956- block_id * self .block_size + kv_block )
957+ block_id * self .block_size + slot_id )
957958 return descs_ids
958959
959960
0 commit comments