|
17 | 17 | from collections import defaultdict |
18 | 18 | from socket import gethostbyaddr |
19 | 19 | from typing import Any |
| 20 | +from typing import cast |
20 | 21 | from typing import Collection |
21 | 22 | from typing import List |
22 | 23 | from typing import Mapping |
|
32 | 33 | from clusterman.aws.client import InstanceDict |
33 | 34 | from clusterman.aws.markets import get_instance_market |
34 | 35 | from clusterman.aws.markets import InstanceMarket |
| 36 | +from clusterman.aws.markets import MarketDict |
35 | 37 | from clusterman.interfaces.resource_group import InstanceMetadata |
36 | 38 | from clusterman.interfaces.resource_group import ResourceGroup |
37 | 39 |
|
@@ -77,7 +79,7 @@ def get_instance_metadatas(self, state_filter: Optional[Collection[str]] = None) |
77 | 79 | if state_filter and aws_state not in state_filter: |
78 | 80 | continue |
79 | 81 |
|
80 | | - instance_market = get_instance_market(instance_dict) |
| 82 | + instance_market = get_instance_market(cast(MarketDict, instance_dict)) |
81 | 83 | instance_ip = instance_dict.get("PrivateIpAddress") |
82 | 84 | hostname = gethostbyaddr(instance_ip)[0] if instance_ip else None |
83 | 85 | is_cordoned = self._is_instance_cordoned(instance_dict) |
@@ -127,14 +129,16 @@ def terminate_instances_by_id(self, instance_ids: List[str], batch_size: int = 5 |
127 | 129 |
|
128 | 130 | instance_weights = {} |
129 | 131 | for instance in ec2_describe_instances(instance_ids): |
130 | | - instance_market = get_instance_market(instance) |
| 132 | + instance_market = get_instance_market(cast(MarketDict, instance)) |
131 | 133 | if not instance_market.az: |
132 | 134 | logger.warning( |
133 | 135 | f"Instance {instance['InstanceId']} missing AZ info, likely already terminated so skipping", |
134 | 136 | ) |
135 | 137 | instance_ids.remove(instance["InstanceId"]) |
136 | 138 | continue |
137 | | - instance_weights[instance["InstanceId"]] = self.market_weight(get_instance_market(instance)) |
| 139 | + instance_weights[instance["InstanceId"]] = self.market_weight( |
| 140 | + get_instance_market(cast(MarketDict, instance)) |
| 141 | + ) |
138 | 142 |
|
139 | 143 | # AWS API recommends not terminating more than 1000 instances at a time, and to |
140 | 144 | # terminate larger numbers in batches |
@@ -186,15 +190,21 @@ def _get_instances_by_market(self): |
186 | 190 | """Responses from this API call are cached to prevent hitting any AWS request limits""" |
187 | 191 | instance_dict: Mapping[InstanceMarket, List[Mapping]] = defaultdict(list) |
188 | 192 | for instance in ec2_describe_instances(self.instance_ids): |
189 | | - instance_dict[get_instance_market(instance)].append(instance) |
| 193 | + instance_dict[get_instance_market(cast(MarketDict, instance))].append(instance) |
190 | 194 | return instance_dict |
191 | 195 |
|
192 | 196 | @abstractproperty |
193 | 197 | def _target_capacity(self): # pragma: no cover |
194 | 198 | pass |
195 | 199 |
|
196 | 200 | @classmethod |
197 | | - def load(cls, cluster: str, pool: str, config: Any, **kwargs: Any) -> Mapping[str, "AWSResourceGroup"]: |
| 201 | + def load( # type: ignore # (mypy errors with "incompatible signature with supertype") |
| 202 | + cls, |
| 203 | + cluster: str, |
| 204 | + pool: str, |
| 205 | + config: Any, |
| 206 | + **kwargs: Any, |
| 207 | + ) -> Mapping[str, "AWSResourceGroup"]: |
198 | 208 | """Load a list of corresponding resource groups |
199 | 209 |
|
200 | 210 | :param cluster: a cluster name |
|
0 commit comments