
Commit d63375e

Merge pull request #164 from awslabs/categories
Add categories argument for all read_parquet related functions
2 parents ee33dcc + 6b24a5d commit d63375e
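
The categories argument introduced by this PR asks each reader to return the listed columns as pandas.Categorical, which the docstrings below recommend for memory restricted environments. As a rough, library-independent illustration of why categorical dtypes help (plain pandas; the data here is made up for the example):

import pandas as pd

# A low-cardinality string column repeated many times (illustrative data only).
s_object = pd.Series(["foo", "bar", "boo"] * 100_000)   # dtype: object
s_category = s_object.astype("category")                # dtype: category

# Categorical storage keeps each distinct value once plus compact integer codes,
# so deep memory usage is typically much lower than for the object column.
print(s_object.memory_usage(deep=True))
print(s_category.memory_usage(deep=True))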

File tree

7 files changed: 362 additions & 33 deletions

awswrangler/athena.py

Lines changed: 12 additions & 4 deletions
@@ -259,7 +259,7 @@ def _extract_ctas_manifest_paths(path: str, boto3_session: Optional[boto3.Sessio
 
 
 def _get_query_metadata(
-    query_execution_id: str, boto3_session: Optional[boto3.Session] = None
+    query_execution_id: str, categories: List[str] = None, boto3_session: Optional[boto3.Session] = None
 ) -> Tuple[Dict[str, str], List[str], List[str], Dict[str, Any], List[str]]:
     """Get query metadata."""
     cols_types: Dict[str, str] = get_query_columns_types(
@@ -285,7 +285,9 @@ def _get_query_metadata(
                 "Please use ctas_approach=True for Struct columns."
             )
         pandas_type: str = _data_types.athena2pandas(dtype=col_type)
-        if pandas_type in ["datetime64", "date"]:
+        if (categories is not None) and (col_name in categories):
+            dtype[col_name] = "category"
+        elif pandas_type in ["datetime64", "date"]:
             parse_timestamps.append(col_name)
             if pandas_type == "date":
                 parse_dates.append(col_name)
@@ -326,6 +328,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
     sql: str,
     database: str,
     ctas_approach: bool = True,
+    categories: List[str] = None,
     chunksize: Optional[int] = None,
     s3_output: Optional[str] = None,
     workgroup: Optional[str] = None,
@@ -377,6 +380,9 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
     ctas_approach: bool
         Wraps the query using a CTAS, and read the resulted parquet data on S3.
         If false, read the regular CSV on S3.
+    categories: List[str], optional
+        List of columns names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
     chunksize: int, optional
         If specified, return an generator where chunksize is the number of rows to include in each chunk.
     s3_output : str, optional
@@ -457,10 +463,12 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals
                 dfs = _utils.empty_generator()
         else:
             s3.wait_objects_exist(paths=paths, use_threads=False, boto3_session=session)
-            dfs = s3.read_parquet(path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked)
+            dfs = s3.read_parquet(
+                path=paths, use_threads=use_threads, boto3_session=session, chunked=chunked, categories=categories
+            )
         return dfs
     dtype, parse_timestamps, parse_dates, converters, binaries = _get_query_metadata(
-        query_execution_id=query_id, boto3_session=session
+        query_execution_id=query_id, categories=categories, boto3_session=session
     )
     path = f"{_s3_output}{query_id}.csv"
     s3.wait_objects_exist(paths=[path], use_threads=False, boto3_session=session)
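
A minimal usage sketch of the new argument on the Athena reader, based on the signature above; the database and table names are hypothetical:

import awswrangler as wr

# "my_db" / "my_table" are placeholder names for this sketch.
df = wr.athena.read_sql_query(
    sql="SELECT id, status FROM my_table",
    database="my_db",
    ctas_approach=True,        # CTAS output is read back via s3.read_parquet(..., categories=...)
    categories=["status"],     # "status" comes back as category dtype instead of object
)
print(df.dtypes)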

awswrangler/db.py

Lines changed: 8 additions & 0 deletions
@@ -887,6 +887,7 @@ def unload_redshift(
     path: str,
     con: sqlalchemy.engine.Engine,
     iam_role: str,
+    categories: List[str] = None,
     chunked: bool = False,
     keep_files: bool = False,
     use_threads: bool = True,
@@ -920,6 +921,9 @@ def unload_redshift(
         wr.db.get_engine(), wr.db.get_redshift_temp_engine() or wr.catalog.get_engine()
     iam_role : str
         AWS IAM role with the related permissions.
+    categories: List[str], optional
+        List of columns names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
     keep_files : bool
         Should keep the stage files?
     chunked : bool
@@ -960,6 +964,7 @@ def unload_redshift(
             return pd.DataFrame()
         df: pd.DataFrame = s3.read_parquet(
             path=paths,
+            categories=categories,
             chunked=chunked,
             dataset=False,
             use_threads=use_threads,
@@ -973,6 +978,7 @@ def unload_redshift(
         return _utils.empty_generator()
     return _read_parquet_iterator(
         paths=paths,
+        categories=categories,
         use_threads=use_threads,
         boto3_session=session,
         s3_additional_kwargs=s3_additional_kwargs,
@@ -984,11 +990,13 @@ def _read_parquet_iterator(
     paths: List[str],
     keep_files: bool,
     use_threads: bool,
+    categories: List[str] = None,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
 ) -> Iterator[pd.DataFrame]:
     dfs: Iterator[pd.DataFrame] = s3.read_parquet(
         path=paths,
+        categories=categories,
         chunked=True,
         dataset=False,
         use_threads=use_threads,
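
A sketch of how the forwarded argument looks from unload_redshift, assuming a catalog connection named "my-redshift-connection" and placeholder bucket/role values:

import awswrangler as wr

engine = wr.catalog.get_engine(connection="my-redshift-connection")  # hypothetical connection name
df = wr.db.unload_redshift(
    sql="SELECT * FROM public.my_table",
    con=engine,
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",  # placeholder role ARN
    path="s3://my-bucket/stage/",                                # placeholder staging path
    categories=["status", "country"],  # unloaded Parquet columns returned as pandas.Categorical
    keep_files=False,
)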

awswrangler/s3.py

Lines changed: 24 additions & 21 deletions
@@ -833,7 +833,7 @@ def _to_parquet_file(
     fs: s3fs.S3FileSystem,
     dtype: Dict[str, str],
 ) -> str:
-    table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=False)
+    table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
     for col_name, col_type in dtype.items():
         if col_name in table.column_names:
             col_index = table.column_names.index(col_name)
@@ -1190,6 +1190,7 @@ def _read_text_full(
 def _read_parquet_init(
     path: Union[str, List[str]],
     filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None,
+    categories: List[str] = None,
     dataset: bool = False,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
@@ -1206,7 +1207,7 @@ def _read_parquet_init(
     fs: s3fs.S3FileSystem = _utils.get_fs(session=boto3_session, s3_additional_kwargs=s3_additional_kwargs)
     cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
     data: pyarrow.parquet.ParquetDataset = pyarrow.parquet.ParquetDataset(
-        path_or_paths=path_or_paths, filesystem=fs, metadata_nthreads=cpus, filters=filters
+        path_or_paths=path_or_paths, filesystem=fs, metadata_nthreads=cpus, filters=filters, read_dictionary=categories
     )
     return data
 
@@ -1217,6 +1218,7 @@ def read_parquet(
     columns: Optional[List[str]] = None,
     chunked: bool = False,
     dataset: bool = False,
+    categories: List[str] = None,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
@@ -1243,6 +1245,9 @@ def read_parquet(
         Otherwise return a single DataFrame with the whole data.
     dataset: bool
         If True read a parquet dataset instead of simple file(s) loading all the related partitions as columns.
+    categories: List[str], optional
+        List of columns names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
     use_threads : bool
         True to enable concurrent requests, False to disable multiple threads.
         If enabled os.cpu_count() will be used as the max number of threads.
@@ -1292,66 +1297,59 @@ def read_parquet(
         path=path,
         filters=filters,
         dataset=dataset,
+        categories=categories,
         use_threads=use_threads,
         boto3_session=boto3_session,
         s3_additional_kwargs=s3_additional_kwargs,
     )
-    common_metadata = data.common_metadata
-    common_metadata = None if common_metadata is None else common_metadata.metadata.get(b"pandas", None)
     if chunked is False:
-        return _read_parquet(data=data, columns=columns, use_threads=use_threads, common_metadata=common_metadata)
-    return _read_parquet_chunked(data=data, columns=columns, use_threads=use_threads, common_metadata=common_metadata)
+        return _read_parquet(data=data, columns=columns, categories=categories, use_threads=use_threads)
+    return _read_parquet_chunked(data=data, columns=columns, categories=categories, use_threads=use_threads)
 
 
 def _read_parquet(
     data: pyarrow.parquet.ParquetDataset,
     columns: Optional[List[str]] = None,
+    categories: List[str] = None,
     use_threads: bool = True,
-    common_metadata: Any = None,
 ) -> pd.DataFrame:
-    # Data
     tables: List[pa.Table] = []
     for piece in data.pieces:
         table: pa.Table = piece.read(
-            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=True
+            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
        )
         tables.append(table)
     table = pa.lib.concat_tables(tables)
-
-    # Metadata
-    current_metadata = table.schema.metadata or {}
-    if common_metadata and b"pandas" not in current_metadata:  # pragma: no cover
-        table = table.replace_schema_metadata({b"pandas": common_metadata})
-
     return table.to_pandas(
         use_threads=use_threads,
         split_blocks=True,
         self_destruct=True,
         integer_object_nulls=False,
         date_as_object=True,
+        ignore_metadata=True,
+        categories=categories,
         types_mapper=_data_types.pyarrow2pandas_extension,
     )
 
 
 def _read_parquet_chunked(
     data: pyarrow.parquet.ParquetDataset,
     columns: Optional[List[str]] = None,
+    categories: List[str] = None,
     use_threads: bool = True,
-    common_metadata: Any = None,
 ) -> Iterator[pd.DataFrame]:
     for piece in data.pieces:
         table: pa.Table = piece.read(
-            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=True
+            columns=columns, use_threads=use_threads, partitions=data.partitions, use_pandas_metadata=False
         )
-        current_metadata = table.schema.metadata or {}
-        if common_metadata and b"pandas" not in current_metadata:  # pragma: no cover
-            table = table.replace_schema_metadata({b"pandas": common_metadata})
         yield table.to_pandas(
             use_threads=use_threads,
             split_blocks=True,
             self_destruct=True,
             integer_object_nulls=False,
             date_as_object=True,
+            ignore_metadata=True,
+            categories=categories,
             types_mapper=_data_types.pyarrow2pandas_extension,
         )
 
@@ -1670,6 +1668,7 @@ def read_parquet_table(
     database: str,
     filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None,
     columns: Optional[List[str]] = None,
+    categories: List[str] = None,
     chunked: bool = False,
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
@@ -1690,7 +1689,10 @@ def read_parquet_table(
     filters: Union[List[Tuple], List[List[Tuple]]], optional
         List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
     columns : List[str], optional
-        Names of columns to read from the file(s)
+        Names of columns to read from the file(s).
+    categories: List[str], optional
+        List of columns names that should be returned as pandas.Categorical.
+        Recommended for memory restricted environments.
     chunked : bool
         If True will break the data in smaller DataFrames (Non deterministic number of lines).
         Otherwise return a single DataFrame with the whole data.
@@ -1740,6 +1742,7 @@ def read_parquet_table(
         path=path,
         filters=filters,
         columns=columns,
+        categories=categories,
         chunked=chunked,
         dataset=True,
         use_threads=use_threads,
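
In s3.read_parquet the argument is forwarded both to pyarrow.parquet.ParquetDataset as read_dictionary and to Table.to_pandas as categories, as the hunks above show. A usage sketch with placeholder S3 and Glue names:

import awswrangler as wr

# "my-bucket", "my-dataset", "my_db" and "my_table" are placeholders for this sketch.
df = wr.s3.read_parquet(
    path="s3://my-bucket/my-dataset/",
    dataset=True,
    categories=["status", "country"],  # read dictionary-encoded and returned as category dtype
)

# The same argument is exposed on the Glue-table variant:
df2 = wr.s3.read_parquet_table(table="my_table", database="my_db", categories=["status"])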

testing/test_awswrangler/_utils.py

Lines changed: 32 additions & 0 deletions
@@ -94,6 +94,25 @@ def get_df_cast():
     return df
 
 
+def get_df_category():
+    df = pd.DataFrame(
+        {
+            "id": [1, 2, 3],
+            "string_object": ["foo", None, "boo"],
+            "string": ["foo", None, "boo"],
+            "binary": [b"1", None, b"2"],
+            "float": [1.0, None, 2.0],
+            "int": [1, None, 2],
+            "par0": [1, 1, 2],
+            "par1": ["a", "b", "b"],
+        }
+    )
+    df["string"] = df["string"].astype("string")
+    df["int"] = df["int"].astype("Int64")
+    df["par1"] = df["par1"].astype("string")
+    return df
+
+
 def get_query_long():
     return """
 SELECT
@@ -324,3 +343,16 @@ def ensure_data_types(df, has_list=False):
     if has_list is True:
         assert str(type(row["list"][0]).__name__) == "int64"
         assert str(type(row["list_list"][0][0]).__name__) == "int64"
+
+
+def ensure_data_types_category(df):
+    assert len(df.columns) in (7, 8)
+    assert str(df["id"].dtype) in ("category", "Int64")
+    assert str(df["string_object"].dtype) == "category"
+    assert str(df["string"].dtype) == "category"
+    if "binary" in df.columns:
+        assert str(df["binary"].dtype) == "category"
+    assert str(df["float"].dtype) == "category"
+    assert str(df["int"].dtype) in ("category", "Int64")
+    assert str(df["par0"].dtype) in ("category", "Int64")
+    assert str(df["par1"].dtype) == "category"

testing/test_awswrangler/test_data_lake.py

Lines changed: 37 additions & 1 deletion
@@ -7,7 +7,8 @@
 
 import awswrangler as wr
 
-from ._utils import ensure_data_types, get_df, get_df_cast, get_df_list, get_query_long
+from ._utils import (ensure_data_types, ensure_data_types_category, get_df, get_df_cast, get_df_category, get_df_list,
+                     get_query_long)
 
 logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
 logging.getLogger("awswrangler").setLevel(logging.DEBUG)
@@ -614,3 +615,38 @@ def test_athena_time_zone(database):
     assert len(df.columns) == 2
     assert df["type"][0] == "timestamp with time zone"
     assert df["value"][0].year == datetime.datetime.utcnow().year
+
+
+def test_category(bucket, database):
+    df = get_df_category()
+    path = f"s3://{bucket}/test_category/"
+    paths = wr.s3.to_parquet(
+        df=df,
+        path=path,
+        dataset=True,
+        database=database,
+        table="test_category",
+        mode="overwrite",
+        partition_cols=["par0", "par1"],
+    )["paths"]
+    wr.s3.wait_objects_exist(paths=paths, use_threads=False)
+    df2 = wr.s3.read_parquet(path=path, dataset=True, categories=[c for c in df.columns if c not in ["par0", "par1"]])
+    ensure_data_types_category(df2)
+    df2 = wr.athena.read_sql_query("SELECT * FROM test_category", database=database, categories=list(df.columns))
+    ensure_data_types_category(df2)
+    df2 = wr.athena.read_sql_query(
+        "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=False
+    )
+    ensure_data_types_category(df2)
+    dfs = wr.athena.read_sql_query(
+        "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=False, chunksize=1
+    )
+    for df2 in dfs:
+        ensure_data_types_category(df2)
+    dfs = wr.athena.read_sql_query(
+        "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=True, chunksize=1
+    )
+    for df2 in dfs:
+        ensure_data_types_category(df2)
+    wr.s3.delete_objects(path=paths)
+    assert wr.catalog.delete_table_if_exists(database=database, table="test_category") is True

testing/test_awswrangler/test_db.py

Lines changed: 37 additions & 1 deletion
@@ -9,7 +9,7 @@
 
 import awswrangler as wr
 
-from ._utils import ensure_data_types, get_df
+from ._utils import ensure_data_types, ensure_data_types_category, get_df, get_df_category
 
 logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s")
 logging.getLogger("awswrangler").setLevel(logging.DEBUG)
@@ -348,3 +348,39 @@ def test_redshift_spectrum(bucket, glue_database, external_schema):
     assert len(rows) == len(df.index)
     for row in rows:
         assert len(row) == len(df.columns)
+
+
+def test_redshift_category(bucket, parameters):
+    path = f"s3://{bucket}/test_redshift_category/"
+    df = get_df_category().drop(["binary"], axis=1, inplace=False)
+    engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-redshift")
+    wr.db.copy_to_redshift(
+        df=df,
+        path=path,
+        con=engine,
+        schema="public",
+        table="test_redshift_category",
+        mode="overwrite",
+        iam_role=parameters["redshift"]["role"],
+    )
+    df2 = wr.db.unload_redshift(
+        sql="SELECT * FROM public.test_redshift_category",
+        con=engine,
+        iam_role=parameters["redshift"]["role"],
+        path=path,
+        keep_files=False,
+        categories=df.columns,
+    )
+    ensure_data_types_category(df2)
+    dfs = wr.db.unload_redshift(
+        sql="SELECT * FROM public.test_redshift_category",
+        con=engine,
+        iam_role=parameters["redshift"]["role"],
+        path=path,
+        keep_files=False,
+        categories=df.columns,
+        chunked=True,
+    )
+    for df2 in dfs:
+        ensure_data_types_category(df2)
+    wr.s3.delete_objects(path=path)
