Commit 10091b7

Refactoring the parquet crawler

1 parent 8f5211f · commit 10091b7

10 files changed: +425 -152 lines

awswrangler/_data_types.py

Lines changed: 6 additions & 14 deletions
@@ -381,20 +381,6 @@ def athena_types_from_pyarrow_schema(
     return columns_types, partitions_types
 
 
-def athena_partitions_from_pyarrow_partitions(
-    path: str, partitions: pyarrow.parquet.ParquetPartitions
-) -> Dict[str, List[str]]:
-    """Extract the related Athena partitions values from any PyArrow Partitions."""
-    path = path if path[-1] == "/" else f"{path}/"
-    partitions_values: Dict[str, List[str]] = {}
-    names: List[str] = [p.name for p in partitions]
-    for values in zip(*[p.keys for p in partitions]):
-        suffix: str = "/".join([f"{n}={v}" for n, v in zip(names, values)])
-        suffix = suffix if suffix[-1] == "/" else f"{suffix}/"
-        partitions_values[f"{path}{suffix}"] = list(values)
-    return partitions_values
-
-
 def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd.DataFrame:
     """Cast columns in a Pandas DataFrame."""
     for col, athena_type in dtype.items():
@@ -412,6 +398,12 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd.DataFrame:
                 .astype("string")
                 .apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", " ", "<NA>") else None)
             )
+        elif pandas_type == "string":
+            curr_type: str = str(df[col].dtypes)
+            if curr_type.startswith("int") or curr_type.startswith("float"):
+                df[col] = df[col].astype(str).astype("string")
+            else:
+                df[col] = df[col].astype("string")
         else:
             df[col] = df[col].astype(pandas_type)
     return df
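A minimal sketch (mine, not part of the commit) of what the new "string" branch handles: on some early pandas 1.x releases, casting an int or float column directly to the nullable "string" dtype raises, so numeric columns are routed through plain str first.

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2, 3]})  # plain int64 column
    # Mirrors the branch above: numeric -> str -> nullable "string" dtype.
    df["a"] = df["a"].astype(str).astype("string")
    print(df["a"].dtype)  # string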

awswrangler/_utils.py

Lines changed: 55 additions & 1 deletion
@@ -3,6 +3,7 @@
 import logging
 import math
 import os
+import random
 from typing import Any, Dict, Generator, List, Optional, Tuple
 
 import boto3  # type: ignore
@@ -11,7 +12,9 @@
 import psycopg2  # type: ignore
 import s3fs  # type: ignore
 
-logger: logging.Logger = logging.getLogger(__name__)
+from awswrangler import exceptions
+
+_logger: logging.Logger = logging.getLogger(__name__)
 
 
 def ensure_session(session: Optional[boto3.Session] = None) -> boto3.Session:
@@ -124,6 +127,8 @@ def chunkify(lst: List[Any], num_chunks: int = 1, max_length: Optional[int] = None
     [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
 
     """
+    if not lst:
+        return []
     n: int = num_chunks if max_length is None else int(math.ceil((float(len(lst)) / float(max_length))))
     np_chunks = np.array_split(lst, n)
     return [arr.tolist() for arr in np_chunks if len(arr) > 0]
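A quick sketch (mine, not from the commit) of what the new guard prevents: with an empty list and max_length set, n evaluates to 0 and numpy.array_split raises "number sections must be larger than 0", so chunkify now short-circuits to an empty list.

    import math
    import numpy as np

    def chunkify_demo(lst, num_chunks=1, max_length=None):
        if not lst:  # the guard added above
            return []
        n = num_chunks if max_length is None else int(math.ceil(len(lst) / float(max_length)))
        return [arr.tolist() for arr in np.array_split(lst, n) if len(arr) > 0]

    print(chunkify_demo([], max_length=4))               # [] instead of a ValueError
    print(chunkify_demo(list(range(13)), max_length=4))  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]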
@@ -179,3 +184,52 @@ def get_region_from_subnet(subnet_id: str, boto3_session: Optional[boto3.Session
     session: boto3.Session = ensure_session(session=boto3_session)
     client_ec2: boto3.client = client(service_name="ec2", session=session)
     return client_ec2.describe_subnets(SubnetIds=[subnet_id])["Subnets"][0]["AvailabilityZone"][:9]
+
+
+def extract_partitions_from_paths(
+    path: str, paths: List[str]
+) -> Tuple[Optional[Dict[str, str]], Optional[Dict[str, List[str]]]]:
+    """Extract partitions from Amazon S3 paths."""
+    path = path if path.endswith("/") else f"{path}/"
+    partitions_types: Dict[str, str] = {}
+    partitions_values: Dict[str, List[str]] = {}
+    for p in paths:
+        if path not in p:
+            raise exceptions.InvalidArgumentValue(f"Object {p} is not under the root path ({path}).")
+        path_wo_filename: str = p.rpartition("/")[0] + "/"
+        if path_wo_filename not in partitions_values:
+            path_wo_prefix: str = p.replace(f"{path}/", "")
+            dirs: List[str] = [x for x in path_wo_prefix.split("/") if (x != "") and ("=" in x)]
+            if dirs:
+                values_tups: List[Tuple[str, str]] = [tuple(x.split("=")[:2]) for x in dirs]  # type: ignore
+                values_dics: Dict[str, str] = dict(values_tups)
+                p_values: List[str] = list(values_dics.values())
+                p_types: Dict[str, str] = {x: "string" for x in values_dics.keys()}
+                if not partitions_types:
+                    partitions_types = p_types
+                if p_values:
+                    partitions_types = p_types
+                    partitions_values[path_wo_filename] = p_values
+                elif p_types != partitions_types:
+                    raise exceptions.InvalidSchemaConvergence(
+                        f"At least two different partitions schema detected: {partitions_types} and {p_types}"
+                    )
+    if not partitions_types:
+        return None, None
+    return partitions_types, partitions_values
+
+
+def list_sampling(lst: List[Any], sampling: float) -> List[Any]:
+    """Random List sampling."""
+    if sampling > 1.0 or sampling <= 0.0:
+        raise exceptions.InvalidArgumentValue(f"Argument <sampling> must be [0.0 < value <= 1.0]. {sampling} received.")
+    _len: int = len(lst)
+    if _len == 0:
+        return []
+    num_samples: int = int(round(_len * sampling))
+    num_samples = _len if num_samples > _len else num_samples
+    num_samples = 1 if num_samples < 1 else num_samples
+    _logger.debug("_len: %s", _len)
+    _logger.debug("sampling: %s", sampling)
+    _logger.debug("num_samples: %s", num_samples)
+    return random.sample(population=lst, k=num_samples)
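Two usage sketches for the new helpers (hypothetical bucket and paths; both functions live in the private awswrangler._utils module, and list_sampling is presumably what lets the refactored crawler inspect only a random fraction of the objects):

    from awswrangler._utils import extract_partitions_from_paths, list_sampling

    types, values = extract_partitions_from_paths(
        path="s3://bucket/table/",
        paths=[
            "s3://bucket/table/year=2020/month=1/file0.parquet",
            "s3://bucket/table/year=2020/month=2/file1.parquet",
        ],
    )
    # types  == {"year": "string", "month": "string"}
    # values == {"s3://bucket/table/year=2020/month=1/": ["2020", "1"],
    #            "s3://bucket/table/year=2020/month=2/": ["2020", "2"]}

    sample = list_sampling([f"s3://bucket/table/file_{i}.parquet" for i in range(100)], sampling=0.25)
    print(len(sample))  # 25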

awswrangler/athena.py

Lines changed: 3 additions & 1 deletion
@@ -282,6 +282,8 @@ def repair_table(
 
     """
     query = f"MSCK REPAIR TABLE `{table}`;"
+    if (database is not None) and (not database.startswith("`")):
+        database = f"`{database}`"
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
     query_id = start_query_execution(
         sql=query,
@@ -492,7 +494,7 @@ def read_sql_query(  # pylint: disable=too-many-branches,too-many-locals,too-man
         path: str = f"{_s3_output}/{name}"
         ext_location: str = "\n" if wg_config["enforced"] is True else f",\n    external_location = '{path}'\n"
         sql = (
-            f"CREATE TABLE {name}\n"
+            f'CREATE TABLE "{name}"\n'
             f"WITH(\n"
             f"    format = 'Parquet',\n"
             f"    parquet_compression = 'SNAPPY'"

awswrangler/catalog.py

Lines changed: 6 additions & 5 deletions
@@ -240,11 +240,12 @@ def add_parquet_partitions(
     ... )
 
     """
-    inputs: List[Dict[str, Any]] = [
-        _parquet_partition_definition(location=k, values=v, compression=compression)
-        for k, v in partitions_values.items()
-    ]
-    _add_partitions(database=database, table=table, boto3_session=boto3_session, inputs=inputs)
+    if partitions_values:
+        inputs: List[Dict[str, Any]] = [
+            _parquet_partition_definition(location=k, values=v, compression=compression)
+            for k, v in partitions_values.items()
+        ]
+        _add_partitions(database=database, table=table, boto3_session=boto3_session, inputs=inputs)
 
 
 def _parquet_partition_definition(location: str, values: List[str], compression: Optional[str]) -> Dict[str, Any]:
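With the new guard, a crawler pass that discovers no partitions can hand over an empty dict and the call returns without touching Glue. A sketch (hypothetical names):

    import awswrangler as wr

    # A no-op now, instead of invoking the internal _add_partitions helper
    # with an empty input list:
    wr.catalog.add_parquet_partitions(
        database="my_db",
        table="my_table",
        partitions_values={},
    )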

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
@@ -83,3 +83,7 @@ class InvalidRedshiftSortkey(Exception):
 
 class InvalidRedshiftPrimaryKeys(Exception):
     """InvalidRedshiftPrimaryKeys exception."""
+
+
+class InvalidSchemaConvergence(Exception):
+    """InvalidSchemaConvergence exception."""
