|
3 | 3 | import logging |
4 | 4 | import math |
5 | 5 | import os |
| 6 | +import random |
6 | 7 | from typing import Any, Dict, Generator, List, Optional, Tuple |
7 | 8 |
|
8 | 9 | import boto3 # type: ignore |
|
11 | 12 | import psycopg2 # type: ignore |
12 | 13 | import s3fs # type: ignore |
13 | 14 |
|
14 | | -logger: logging.Logger = logging.getLogger(__name__) |
| 15 | +from awswrangler import exceptions |
| 16 | + |
| 17 | +_logger: logging.Logger = logging.getLogger(__name__) |
15 | 18 |
|
16 | 19 |
|
17 | 20 | def ensure_session(session: Optional[boto3.Session] = None) -> boto3.Session: |
18 | 21 | """Ensure that a valid boto3.Session will be returned.""" |
19 | 22 | if session is not None: |
20 | 23 | return session |
21 | | - return boto3.Session() |
| 24 | + # Ensure the boto3's default session is used so that its parameters can be |
| 25 | + # set via boto3.setup_default_session() |
| 26 | + if boto3.DEFAULT_SESSION is not None: |
| 27 | + return boto3.DEFAULT_SESSION |
| 28 | + return boto3.Session() # pragma: no cover |
22 | 29 |
|
23 | 30 |
|
24 | 31 | def client(service_name: str, session: Optional[boto3.Session] = None) -> boto3.client: |
@@ -124,6 +131,8 @@ def chunkify(lst: List[Any], num_chunks: int = 1, max_length: Optional[int] = No |
124 | 131 | [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] |
125 | 132 |
|
126 | 133 | """ |
| 134 | + if not lst: |
| 135 | + return [] # pragma: no cover |
127 | 136 | n: int = num_chunks if max_length is None else int(math.ceil((float(len(lst)) / float(max_length)))) |
128 | 137 | np_chunks = np.array_split(lst, n) |
129 | 138 | return [arr.tolist() for arr in np_chunks if len(arr) > 0] |
@@ -179,3 +188,54 @@ def get_region_from_subnet(subnet_id: str, boto3_session: Optional[boto3.Session |
179 | 188 | session: boto3.Session = ensure_session(session=boto3_session) |
180 | 189 | client_ec2: boto3.client = client(service_name="ec2", session=session) |
181 | 190 | return client_ec2.describe_subnets(SubnetIds=[subnet_id])["Subnets"][0]["AvailabilityZone"][:9] |
| 191 | + |
| 192 | + |
| 193 | +def extract_partitions_from_paths( |
| 194 | + path: str, paths: List[str] |
| 195 | +) -> Tuple[Optional[Dict[str, str]], Optional[Dict[str, List[str]]]]: |
| 196 | + """Extract partitions from Amazon S3 paths.""" |
| 197 | + path = path if path.endswith("/") else f"{path}/" |
| 198 | + partitions_types: Dict[str, str] = {} |
| 199 | + partitions_values: Dict[str, List[str]] = {} |
| 200 | + for p in paths: |
| 201 | + if path not in p: |
| 202 | + raise exceptions.InvalidArgumentValue( |
| 203 | + f"Object {p} is not under the root path ({path})." |
| 204 | + ) # pragma: no cover |
| 205 | + path_wo_filename: str = p.rpartition("/")[0] + "/" |
| 206 | + if path_wo_filename not in partitions_values: |
| 207 | + path_wo_prefix: str = p.replace(f"{path}/", "") |
| 208 | + dirs: List[str] = [x for x in path_wo_prefix.split("/") if (x != "") and ("=" in x)] |
| 209 | + if dirs: |
| 210 | + values_tups: List[Tuple[str, str]] = [tuple(x.split("=")[:2]) for x in dirs] # type: ignore |
| 211 | + values_dics: Dict[str, str] = dict(values_tups) |
| 212 | + p_values: List[str] = list(values_dics.values()) |
| 213 | + p_types: Dict[str, str] = {x: "string" for x in values_dics.keys()} |
| 214 | + if not partitions_types: |
| 215 | + partitions_types = p_types |
| 216 | + if p_values: |
| 217 | + partitions_types = p_types |
| 218 | + partitions_values[path_wo_filename] = p_values |
| 219 | + elif p_types != partitions_types: # pragma: no cover |
| 220 | + raise exceptions.InvalidSchemaConvergence( |
| 221 | + f"At least two different partitions schema detected: {partitions_types} and {p_types}" |
| 222 | + ) |
| 223 | + if not partitions_types: |
| 224 | + return None, None |
| 225 | + return partitions_types, partitions_values |
| 226 | + |
| 227 | + |
| 228 | +def list_sampling(lst: List[Any], sampling: float) -> List[Any]: |
| 229 | + """Random List sampling.""" |
| 230 | + if sampling > 1.0 or sampling <= 0.0: # pragma: no cover |
| 231 | + raise exceptions.InvalidArgumentValue(f"Argument <sampling> must be [0.0 < value <= 1.0]. {sampling} received.") |
| 232 | + _len: int = len(lst) |
| 233 | + if _len == 0: |
| 234 | + return [] # pragma: no cover |
| 235 | + num_samples: int = int(round(_len * sampling)) |
| 236 | + num_samples = _len if num_samples > _len else num_samples |
| 237 | + num_samples = 1 if num_samples < 1 else num_samples |
| 238 | + _logger.debug("_len: %s", _len) |
| 239 | + _logger.debug("sampling: %s", sampling) |
| 240 | + _logger.debug("num_samples: %s", num_samples) |
| 241 | + return random.sample(population=lst, k=num_samples) |
0 commit comments