11import logging
2- from collections import defaultdict , namedtuple
2+ from collections import defaultdict
33from copy import deepcopy
4+ from typing import Any , Dict , List , NamedTuple , Sequence , Set , Tuple
5+
6+ from aiokafka .structs import TopicPartition
47
58log = logging .getLogger (__name__ )
69
710
8- ConsumerPair = namedtuple ("ConsumerPair" , ["src_member_id" , "dst_member_id" ])
11+ class ConsumerPair (NamedTuple ):
12+ src_member_id : str
13+ dst_member_id : str
14+
15+
916"""
1017Represents a pair of Kafka consumer ids involved in a partition reassignment.
1118Each ConsumerPair corresponds to a particular partition or topic, indicates that the
1623"""
1724
1825
19- def is_sublist (source , target ) :
26+ def is_sublist (source : Sequence [ Any ] , target : Sequence [ Any ]) -> bool :
2027 """Checks if one list is a sublist of another.
2128
2229 Arguments:
@@ -40,11 +47,13 @@ class PartitionMovements:
4047 form a ConsumerPair object) for each partition.
4148 """
4249
43- def __init__ (self ):
44- self .partition_movements_by_topic = defaultdict (lambda : defaultdict (set ))
45- self .partition_movements = {}
50+ def __init__ (self ) -> None :
51+ self .partition_movements_by_topic : Dict [ str , Dict [ ConsumerPair , Set [ TopicPartition ]]] = defaultdict (lambda : defaultdict (set )) # fmt: skip # noqa: E501
52+ self .partition_movements : Dict [ TopicPartition , ConsumerPair ] = {}
4653
47- def move_partition (self , partition , old_consumer , new_consumer ):
54+ def move_partition (
55+ self , partition : TopicPartition , old_consumer : str , new_consumer : str
56+ ) -> None :
4857 pair = ConsumerPair (src_member_id = old_consumer , dst_member_id = new_consumer )
4958 if partition in self .partition_movements :
5059 # this partition has previously moved
@@ -62,7 +71,9 @@ def move_partition(self, partition, old_consumer, new_consumer):
6271 else :
6372 self ._add_partition_movement_record (partition , pair )
6473
65- def get_partition_to_be_moved (self , partition , old_consumer , new_consumer ):
74+ def get_partition_to_be_moved (
75+ self , partition : TopicPartition , old_consumer : str , new_consumer : str
76+ ) -> TopicPartition :
6677 if partition .topic not in self .partition_movements_by_topic :
6778 return partition
6879 if partition in self .partition_movements :
@@ -79,7 +90,7 @@ def get_partition_to_be_moved(self, partition, old_consumer, new_consumer):
7990 iter (self .partition_movements_by_topic [partition .topic ][reverse_pair ])
8091 )
8192
82- def are_sticky (self ):
93+ def are_sticky (self ) -> bool :
8394 for topic , movements in self .partition_movements_by_topic .items ():
8495 movement_pairs = set (movements .keys ())
8596 if self ._has_cycles (movement_pairs ):
@@ -93,7 +104,9 @@ def are_sticky(self):
93104 return False
94105 return True
95106
96- def _remove_movement_record_of_partition (self , partition ):
107+ def _remove_movement_record_of_partition (
108+ self , partition : TopicPartition
109+ ) -> ConsumerPair :
97110 pair = self .partition_movements [partition ]
98111 del self .partition_movements [partition ]
99112
@@ -105,16 +118,18 @@ def _remove_movement_record_of_partition(self, partition):
105118
106119 return pair
107120
108- def _add_partition_movement_record (self , partition , pair ):
121+ def _add_partition_movement_record (
122+ self , partition : TopicPartition , pair : ConsumerPair
123+ ) -> None :
109124 self .partition_movements [partition ] = pair
110125 self .partition_movements_by_topic [partition .topic ][pair ].add (partition )
111126
112- def _has_cycles (self , consumer_pairs ) :
113- cycles = set ()
127+ def _has_cycles (self , consumer_pairs : Set [ ConsumerPair ]) -> bool :
128+ cycles : Set [ Tuple [ str , ...]] = set ()
114129 for pair in consumer_pairs :
115130 reduced_pairs = deepcopy (consumer_pairs )
116131 reduced_pairs .remove (pair )
117- path = [pair .src_member_id ]
132+ path : List [ str ] = [pair .src_member_id ]
118133 if self ._is_linked (
119134 pair .dst_member_id , pair .src_member_id , reduced_pairs , path
120135 ) and not self ._is_subcycle (path , cycles ):
@@ -132,7 +147,7 @@ def _has_cycles(self, consumer_pairs):
132147 )
133148
134149 @staticmethod
135- def _is_subcycle (cycle , cycles ) :
150+ def _is_subcycle (cycle : List [ str ] , cycles : Set [ Tuple [ str , ...]]) -> bool :
136151 super_cycle = deepcopy (cycle )
137152 super_cycle = super_cycle [:- 1 ]
138153 super_cycle .extend (cycle )
@@ -141,7 +156,9 @@ def _is_subcycle(cycle, cycles):
141156 return True
142157 return False
143158
144- def _is_linked (self , src , dst , pairs , current_path ):
159+ def _is_linked (
160+ self , src : str , dst : str , pairs : Set [ConsumerPair ], current_path : List [str ]
161+ ) -> bool :
145162 if src == dst :
146163 return False
147164 if not pairs :
0 commit comments