
Commit 541d5fb

Merge branch 'master' of github.com:awslabs/aws-data-wrangler
2 parents: 65f865a + cf0690b

8 files changed (+442 / -83 lines)


awswrangler/athena.py

Lines changed: 16 additions & 7 deletions
@@ -259,7 +259,7 @@ def _extract_ctas_manifest_paths(path: str, boto3_session: Optional[boto3.Sessio
 
 
 def _get_query_metadata(
-    query_execution_id: str, boto3_session: Optional[boto3.Session] = None
+    query_execution_id: str, categories: List[str] = None, boto3_session: Optional[boto3.Session] = None
 ) -> Tuple[Dict[str, str], List[str], List[str], Dict[str, Any], List[str]]:
     """Get query metadata."""
     cols_types: Dict[str, str] = get_query_columns_types(
@@ -285,7 +285,9 @@ def _get_query_metadata(
                 "Please use ctas_approach=True for Struct columns."
             )
         pandas_type: str = _data_types.athena2pandas(dtype=col_type)
-        if pandas_type in ["datetime64", "date"]:
+        if (categories is not None) and (col_name in categories):
+            dtype[col_name] = "category"
+        elif pandas_type in ["datetime64", "date"]:
             parse_timestamps.append(col_name)
             if pandas_type == "date":
                 parse_dates.append(col_name)
@@ -326,6 +328,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
     sql: str,
     database: str,
     ctas_approach: bool = True,
+    categories: List[str] = None,
     chunksize: Optional[int] = None,
     s3_output: Optional[str] = None,
     workgroup: Optional[str] = None,
@@ -377,6 +380,9 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
     ctas_approach: bool
         Wraps the query using a CTAS, and read the resulted parquet data on S3.
         If false, read the regular CSV on S3.
+    categories: List[str], optional
+        List of columns names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
     chunksize: int, optional
         If specified, return an generator where chunksize is the number of rows to include in each chunk.
     s3_output : str, optional
@@ -457,10 +463,12 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
                 dfs = _utils.empty_generator()
         else:
             s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
-            dfs = s3.read_parquet(path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked)
+            dfs = s3.read_parquet(
+                path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked, categories=categories
+            )
         return dfs
     dtype, parse_timestamps, parse_dates, converters, binaries = _get_query_metadata(
-        query_execution_id=query_id, boto3_session=session
+        query_execution_id=query_id, categories=categories, boto3_session=session
     )
     path = f"{_s3_output}{query_id}.csv"
     s3.wait_objects_exist(paths=[path], use_threads=False, boto3_session=session)
@@ -539,12 +547,13 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None
 def _ensure_workgroup(
     session: boto3.Session, workgroup: Optional[str] = None
 ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
-    if workgroup:
+    if workgroup is not None:
        res: Dict[str, Any] = get_work_group(workgroup=workgroup, boto3_session=session)
        config: Dict[str, Any] = res["WorkGroup"]["Configuration"]["ResultConfiguration"]
        wg_s3_output: Optional[str] = config.get("OutputLocation")
-        wg_encryption: Optional[str] = config["EncryptionConfiguration"].get("EncryptionOption")
-        wg_kms_key: Optional[str] = config["EncryptionConfiguration"].get("KmsKey")
+        encrypt_config: Optional[Dict[str, str]] = config.get("EncryptionConfiguration")
+        wg_encryption: Optional[str] = None if encrypt_config is None else encrypt_config.get("EncryptionOption")
+        wg_kms_key: Optional[str] = None if encrypt_config is None else encrypt_config.get("KmsKey")
     else:
         wg_s3_output, wg_encryption, wg_kms_key = None, None, None
     return wg_s3_output, wg_encryption, wg_kms_key
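
Taken together, these hunks thread the new `categories` argument from `read_sql_query` down to both result readers: with `ctas_approach=True` the list is forwarded to `s3.read_parquet`, and with `ctas_approach=False` the matching columns are mapped to the pandas "category" dtype in `_get_query_metadata` before the CSV is parsed. A minimal usage sketch (bucket, database, and column names below are hypothetical):

import awswrangler as wr

# Hypothetical table and columns; low-cardinality string columns benefit most
# from being returned as pandas.Categorical.
df = wr.athena.read_sql_query(
    sql="SELECT order_id, status, country FROM orders",
    database="my_database",
    ctas_approach=True,                # read the CTAS Parquet result from S3
    categories=["status", "country"],  # returned with dtype "category"
)
print(df.dtypes)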

awswrangler/catalog.py

Lines changed: 18 additions & 13 deletions
@@ -730,26 +730,21 @@ def table(
 
 
 def _sanitize_name(name: str) -> str:
-    name = "".join(c for c in unicodedata.normalize("NFD", name) if unicodedata.category(c) != "Mn")
-    name = name.replace("{", "_")
-    name = name.replace("}", "_")
-    name = name.replace("]", "_")
-    name = name.replace("[", "_")
-    name = name.replace(")", "_")
-    name = name.replace("(", "_")
-    name = name.replace(" ", "_")
-    name = name.replace("-", "_")
-    name = name.replace(".", "_")
-    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
-    name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name)
-    return name.lower()
+    name = "".join(c for c in unicodedata.normalize("NFD", name) if unicodedata.category(c) != "Mn")  # strip accents
+    name = re.sub("[^A-Za-z0-9_]+", "_", name)  # Removing non alphanumeric characters
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()  # Converting CamelCase to snake_case
 
 
 def sanitize_column_name(column: str) -> str:
     """Convert the column name to be compatible with Amazon Athena.
 
     https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html
 
+    Possible transformations:
+    - Strip accents
+    - Remove non alphanumeric characters
+    - Convert CamelCase to snake_case
+
     Parameters
     ----------
     column : str
@@ -775,6 +770,11 @@ def sanitize_dataframe_columns_names(df: pd.DataFrame) -> pd.DataFrame:
 
     https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html
 
+    Possible transformations:
+    - Strip accents
+    - Remove non alphanumeric characters
+    - Convert CamelCase to snake_case
+
     Parameters
     ----------
     df : pandas.DataFrame
@@ -800,6 +800,11 @@ def sanitize_table_name(table: str) -> str:
 
     https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html
 
+    Possible transformations:
+    - Strip accents
+    - Remove non alphanumeric characters
+    - Convert CamelCase to snake_case
+
     Parameters
     ----------
     table : str
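
For reference, the consolidated `_sanitize_name` can be exercised in isolation; the sample inputs and expected outputs below are illustrative, not taken from the repository's test suite:

import re
import unicodedata


def _sanitize_name(name: str) -> str:  # body copied from the diff above
    name = "".join(c for c in unicodedata.normalize("NFD", name) if unicodedata.category(c) != "Mn")  # strip accents
    name = re.sub("[^A-Za-z0-9_]+", "_", name)  # Removing non alphanumeric characters
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()  # Converting CamelCase to snake_case


print(_sanitize_name("CamelCaseColumn"))  # camel_case_column
print(_sanitize_name("preço (R$)"))       # preco_r_

Note that the single character-class substitution also covers characters the old replace chain missed (for example "$" or "/"), which is the main behavioral difference besides brevity.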

awswrangler/db.py

Lines changed: 8 additions & 0 deletions
@@ -887,6 +887,7 @@ def unload_redshift(
     path: str,
     con: sqlalchemy.engine.Engine,
     iam_role: str,
+    categories: List[str] = None,
     chunked: bool = False,
     keep_files: bool = False,
     use_threads: bool = True,
@@ -920,6 +921,9 @@ def unload_redshift(
         wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
     iam_role : str
         AWS IAM role with the related permissions.
+    categories: List[str], optional
+        List of columns names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
     keep_files : bool
         Should keep the stage files?
     chunked : bool
@@ -960,6 +964,7 @@ def unload_redshift(
             return pd.DataFrame()
         df: pd.DataFrame = s3.read_parquet(
             path=paths,
+            categories=categories,
             chunked=chunked,
             dataset=False,
             use_threads=use_threads,
@@ -973,6 +978,7 @@ def unload_redshift(
         return _utils.empty_generator()
     return _read_parquet_iterator(
         paths=paths,
+        categories=categories,
         use_threads=use_threads,
         boto3_session=session,
         s3_additional_kwargs=s3_additional_kwargs,
@@ -984,11 +990,13 @@ def _read_parquet_iterator(
     paths: List[str],
     keep_files: bool,
     use_threads: bool,
+    categories: List[str] = None,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
 ) -> Iterator[pd.DataFrame]:
     dfs: Iterator[pd.DataFrame] = s3.read_parquet(
         path=paths,
+        categories=categories,
         chunked=True,
         dataset=False,
         use_threads=use_threads,
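
A usage sketch for the new `unload_redshift` argument; the Glue connection name, staging path, and IAM role ARN are placeholders:

import awswrangler as wr

# Placeholders: adjust the connection name, staging path, and role ARN.
engine = wr.catalog.get_engine(connection="my-redshift-connection")
df = wr.db.unload_redshift(
    sql="SELECT user_id, plan, country FROM public.users",
    path="s3://my-bucket/stage/",
    con=engine,
    iam_role="arn:aws:iam::111111111111:role/my-redshift-role",
    categories=["plan", "country"],  # returned as pandas.Categorical
    keep_files=False,
)

Because UNLOAD writes Parquet to the staging path, the `categories` list is simply forwarded to `s3.read_parquet` (or to `_read_parquet_iterator` when `chunked=True`), as shown in the hunks above.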

awswrangler/s3.py

Lines changed: 29 additions & 21 deletions
@@ -530,6 +530,11 @@ def to_parquet( # pylint: disable=too-many-arguments
     The concept of Dataset goes beyond the simple idea of files and enable more
     complex features like partitioning, casting and catalog integration (Amazon Athena/AWS Glue Catalog).
 
+    Note
+    ----
+    The table name and all column names will be automatically sanitize using
+    `wr.catalog.sanitize_table_name` and `wr.catalog.sanitize_column_name`.
+
     Note
     ----
     In case of `use_threads=True` the number of process that will be spawned will be get from os.cpu_count().
@@ -833,7 +838,7 @@ def _to_parquet_file(
     fs: s3fs.S3FileSystem,
     dtype: Dict[str, str],
 ) -> str:
-    table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=False)
+    table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
     for col_name, col_type in dtype.items():
         if col_name in table.column_names:
             col_index = table.column_names.index(col_name)
@@ -1190,6 +1195,7 @@ def _read_text_full(
 def _read_parquet_init(
     path: Union[str, List[str]],
     filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None,
+    categories: List[str] = None,
     dataset: bool = False,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
@@ -1206,7 +1212,7 @@ def _read_parquet_init(
     fs: s3fs.S3FileSystem = _utils.get_fs(session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
     cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
     data: pyarrow.parquet.ParquetDataset = pyarrow.parquet.ParquetDataset(
-        path_or_paths=path_or_paths, filesystem=fs, metadata_nthreads=cpus, filters=filters
+        path_or_paths=path_or_paths, filesystem=fs, metadata_nthreads=cpus, filters=filters, read_dictionary=categories
     )
     return data
 
@@ -1217,6 +1223,7 @@ def read_parquet(
     columns: Optional[List[str]] = None,
     chunked: bool = False,
     dataset: bool = False,
+    categories: List[str] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
@@ -1243,6 +1250,9 @@ def read_parquet(
         Otherwise return a single DataFrame with the whole data.
     dataset: bool
         If True read a parquet dataset instead of simple file(s) loading all the related partitions as columns.
+    categories: List[str], optional
+        List of columns names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -1292,66 +1302,59 @@ def read_parquet(
         path=path,
         filters=filters,
         dataset=dataset,
+        categories=categories,
         use_threads=use_threads,
         boto3_session=boto3_session,
         s3_additional_kwargs=s3_additional_kwargs,
     )
-    common_metadata = data.common_metadata
-    common_metadata = None if common_metadata is None else common_metadata.metadata.get(b"pandas", None)
     if chunked is False:
-        return _read_parquet(data=data, columns=columns, use_threads=use_threads, common_metadata=common_metadata)
-    return _read_parquet_chunked(data=data, columns=columns, use_threads=use_threads, common_metadata=common_metadata)
+        return _read_parquet(data=data, columns=columns, categories=categories, use_threads=use_threads)
+    return _read_parquet_chunked(data=data, columns=columns, categories=categories, use_threads=use_threads)
 
 
 def _read_parquet(
     data: pyarrow.parquet.ParquetDataset,
     columns: Optional[List[str]] = None,
+    categories: List[str] = None,
     use_threads: bool = True,
-    common_metadata: Any = None,
 ) -> pd.DataFrame:
-    # Data
     tables: List[pa.Table] = []
     for piece in data.pieces:
         table: pa.Table = piece.read(
-            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=True
+            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
         )
         tables.append(table)
     table = pa.lib.concat_tables(tables)
-
-    # Metadata
-    current_metadata = table.schema.metadata or {}
-    if common_metadata and b"pandas" not in current_metadata:  # pragma: no cover
-        table = table.replace_schema_metadata({b"pandas": common_metadata})
-
     return table.to_pandas(
         use_threads=use_threads,
         split_blocks=True,
         self_destruct=True,
         integer_object_nulls=False,
         date_as_object=True,
+        ignore_metadata=True,
+        categories=categories,
         types_mapper=_data_types.pyarrow2pandas_extension,
     )
 
 
 def _read_parquet_chunked(
     data: pyarrow.parquet.ParquetDataset,
     columns: Optional[List[str]] = None,
+    categories: List[str] = None,
     use_threads: bool = True,
-    common_metadata: Any = None,
 ) -> Iterator[pd.DataFrame]:
     for piece in data.pieces:
         table: pa.Table = piece.read(
-            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=True
+            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
         )
-        current_metadata = table.schema.metadata or {}
-        if common_metadata and b"pandas" not in current_metadata:  # pragma: no cover
-            table = table.replace_schema_metadata({b"pandas": common_metadata})
        yield table.to_pandas(
            use_threads=use_threads,
            split_blocks=True,
            self_destruct=True,
            integer_object_nulls=False,
            date_as_object=True,
+            ignore_metadata=True,
+            categories=categories,
            types_mapper=_data_types.pyarrow2pandas_extension,
        )
 
@@ -1670,6 +1673,7 @@ def read_parquet_table(
     database: str,
     filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None,
     columns: Optional[List[str]] = None,
+    categories: List[str] = None,
     chunked: bool = False,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
@@ -1690,7 +1694,10 @@ def read_parquet_table(
     filters: Union[List[Tuple], List[List[Tuple]]], optional
         List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
     columns : List[str], optional
-        Names of columns to read from the file(s)
+        Names of columns to read from the file(s).
+    categories: List[str], optional
+        List of columns names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
     chunked : bool
         If True will break the data in smaller DataFrames (Non deterministic number of lines).
         Otherwise return a single DataFrame with the whole data.
@@ -1740,6 +1747,7 @@ def read_parquet_table(
         path=path,
         filters=filters,
         columns=columns,
+        categories=categories,
         chunked=chunked,
         dataset=True,
         use_threads=use_threads,
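
The `categories` plumbing in this file does the actual work: `_read_parquet_init` passes the list to `pyarrow.parquet.ParquetDataset(read_dictionary=...)` so those columns are read as dictionary-encoded arrays, and `_read_parquet` / `_read_parquet_chunked` pass it to `Table.to_pandas(categories=...)` so they come back as `pandas.Categorical`. The diff also switches to `use_pandas_metadata=False` and `ignore_metadata=True`, so the pandas schema stored in the files no longer influences the resulting dtypes. A short sketch of the public entry point (path and column names are hypothetical):

import awswrangler as wr

# Hypothetical partitioned dataset written by wr.s3.to_parquet(..., dataset=True).
df = wr.s3.read_parquet(
    path="s3://my-bucket/dataset/",
    dataset=True,
    categories=["country", "status"],  # kept dictionary-encoded end to end
)
print(df["country"].dtype)  # category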

testing/test_awswrangler/_utils.py

Lines changed: 32 additions & 0 deletions
@@ -94,6 +94,25 @@ def get_df_cast():
     return df
 
 
+def get_df_category():
+    df = pd.DataFrame(
+        {
+            "id": [1, 2, 3],
+            "string_object": ["foo", None, "boo"],
+            "string": ["foo", None, "boo"],
+            "binary": [b"1", None, b"2"],
+            "float": [1.0, None, 2.0],
+            "int": [1, None, 2],
+            "par0": [1, 1, 2],
+            "par1": ["a", "b", "b"],
+        }
+    )
+    df["string"] = df["string"].astype("string")
+    df["int"] = df["int"].astype("Int64")
+    df["par1"] = df["par1"].astype("string")
+    return df
+
+
 def get_query_long():
     return """
 SELECT
@@ -324,3 +343,16 @@ def ensure_data_types(df, has_list=False):
     if has_list is True:
         assert str(type(row["list"][0]).__name__) == "int64"
         assert str(type(row["list_list"][0][0]).__name__) == "int64"
+
+
+def ensure_data_types_category(df):
+    assert len(df.columns) in (7, 8)
+    assert str(df["id"].dtype) in ("category", "Int64")
+    assert str(df["string_object"].dtype) == "category"
+    assert str(df["string"].dtype) == "category"
+    if "binary" in df.columns:
+        assert str(df["binary"].dtype) == "category"
+    assert str(df["float"].dtype) == "category"
+    assert str(df["int"].dtype) in ("category", "Int64")
+    assert str(df["par0"].dtype) in ("category", "Int64")
+    assert str(df["par1"].dtype) == "category"
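
A sketch of how these two helpers could be combined in a test; the bucket fixture, S3 path, relative import, and exact call arguments are assumptions about the surrounding test module, and the tests added alongside this commit may differ:

import awswrangler as wr

from ._utils import ensure_data_types_category, get_df_category  # helpers added above


def test_parquet_category_roundtrip(bucket):  # "bucket" assumed to be a pytest fixture
    path = f"s3://{bucket}/test_parquet_category_roundtrip/"
    paths = wr.s3.to_parquet(df=get_df_category(), path=path, dataset=True, mode="overwrite")["paths"]
    wr.s3.wait_objects_exist(paths=paths)
    cols = ["id", "string_object", "string", "binary", "float", "int", "par0", "par1"]
    df = wr.s3.read_parquet(path=path, dataset=True, categories=cols)
    ensure_data_types_category(df)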
