Skip to content

Commit dff4aa6

Browse files
authored
Extending logic to add_csv_partitions and leveraging catalog_table_input (#674)
* Extending logic to add_csv_partitions and leveraging catalog_table_input * Adapting catalog versioning test
1 parent ca73c19 commit dff4aa6

File tree

5 files changed

+79
-22
lines changed

5 files changed

+79
-22
lines changed

awswrangler/catalog/_add.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ def add_csv_partitions(
4848
catalog_id: Optional[str] = None,
4949
compression: Optional[str] = None,
5050
sep: str = ",",
51+
serde_library: Optional[str] = None,
52+
serde_parameters: Optional[Dict[str, str]] = None,
5153
boto3_session: Optional[boto3.Session] = None,
5254
columns_types: Optional[Dict[str, str]] = None,
5355
) -> None:
54-
"""Add partitions (metadata) to a CSV Table in the AWS Glue Catalog.
56+
r"""Add partitions (metadata) to a CSV Table in the AWS Glue Catalog.
5557
5658
Parameters
5759
----------
@@ -73,6 +75,13 @@ def add_csv_partitions(
7375
Compression style (``None``, ``gzip``, etc).
7476
sep : str
7577
String of length 1. Field delimiter for the output file.
78+
serde_library : Optional[str]
79+
Specifies the SerDe Serialization library which will be used. You need to provide the Class library name
80+
as a string (the fully qualified SerDe class name).
81+
If no library is provided the default is `org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe`.
82+
serde_parameters : Optional[Dict[str, str]]
83+
Dictionary of initialization parameters for the SerDe.
84+
The default is `{"field.delim": sep, "escape.delim": "\\"}`.
7685
boto3_session : boto3.Session(), optional
7786
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
7887
columns_types: Optional[Dict[str, str]]
@@ -107,6 +116,8 @@ def add_csv_partitions(
107116
compression=compression,
108117
sep=sep,
109118
columns_types=columns_types,
119+
serde_library=serde_library,
120+
serde_parameters=serde_parameters,
110121
)
111122
for k, v in partitions_values.items()
112123
]

awswrangler/catalog/_definitions.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,19 +152,24 @@ def _csv_partition_definition(
152152
bucketing_info: Optional[Tuple[List[str], int]],
153153
compression: Optional[str],
154154
sep: str,
155+
serde_library: Optional[str],
156+
serde_parameters: Optional[Dict[str, str]],
155157
columns_types: Optional[Dict[str, str]],
156158
) -> Dict[str, Any]:
157159
compressed: bool = compression is not None
160+
serde_info = {
161+
"SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
162+
if serde_library is None
163+
else serde_library,
164+
"Parameters": {"field.delim": sep, "escape.delim": "\\"} if serde_parameters is None else serde_parameters,
165+
}
158166
definition: Dict[str, Any] = {
159167
"StorageDescriptor": {
160168
"InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
161169
"OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
162170
"Location": location,
163171
"Compressed": compressed,
164-
"SerdeInfo": {
165-
"Parameters": {"field.delim": sep, "escape.delim": "\\"},
166-
"SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
167-
},
172+
"SerdeInfo": serde_info,
168173
"StoredAsSubDirectories": False,
169174
"NumberOfBuckets": -1 if bucketing_info is None else bucketing_info[1],
170175
"BucketColumns": [] if bucketing_info is None else bucketing_info[0],

awswrangler/s3/_write_text.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,11 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
501501
columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
502502
df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
503503
)
504+
serde_info: Dict[str, Any] = {}
505+
if catalog_table_input:
506+
serde_info = catalog_table_input["StorageDescriptor"]["SerdeInfo"]
507+
serde_library: Optional[str] = serde_info.get("SerializationLibrary", None)
508+
serde_parameters: Optional[Dict[str, str]] = serde_info.get("Parameters", None)
504509
catalog._create_csv_table( # pylint: disable=protected-access
505510
database=database,
506511
table=table,
@@ -525,8 +530,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
525530
catalog_id=catalog_id,
526531
compression=pandas_kwargs.get("compression"),
527532
skip_header_line_count=None,
528-
serde_library=None,
529-
serde_parameters=None,
533+
serde_library=serde_library,
534+
serde_parameters=serde_parameters,
530535
)
531536
if partitions_values and (regular_partitions is True):
532537
_logger.debug("partitions_values:\n%s", partitions_values)
@@ -537,6 +542,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
537542
bucketing_info=bucketing_info,
538543
boto3_session=session,
539544
sep=sep,
545+
serde_library=serde_library,
546+
serde_parameters=serde_parameters,
540547
catalog_id=catalog_id,
541548
columns_types=columns_types,
542549
compression=pandas_kwargs.get("compression"),

tests/test_athena_csv.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -451,15 +451,31 @@ def test_csv_compressed(path, glue_table, glue_database, use_threads, concurrent
451451
@pytest.mark.parametrize("use_threads", [True, False])
452452
@pytest.mark.parametrize("ctas_approach", [True, False])
453453
def test_opencsv_serde(path, glue_table, glue_database, use_threads, ctas_approach):
454-
df = pd.DataFrame({"c0": ['"1"', '"2"', '"3"'], "c1": ['"4"', '"5"', '"6"'], "c2": ['"a"', '"b"', '"c"']})
455-
wr.s3.to_csv(
456-
df=df, path=f"{path}0.csv", sep=",", index=False, header=False, use_threads=use_threads, quoting=csv.QUOTE_NONE
454+
df = pd.DataFrame({"col": ["1", "2", "3"], "col2": ["A", "A", "B"]})
455+
response = wr.s3.to_csv(
456+
df=df,
457+
path=path,
458+
dataset=True,
459+
partition_cols=["col2"],
460+
sep=",",
461+
index=False,
462+
header=False,
463+
use_threads=use_threads,
464+
quoting=csv.QUOTE_NONE,
457465
)
458466
wr.catalog.create_csv_table(
459467
database=glue_database,
460468
table=glue_table,
461469
path=path,
462-
columns_types={"c0": "string", "c1": "string", "c2": "string"},
470+
columns_types={"col": "string"},
471+
partitions_types={"col2": "string"},
472+
serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
473+
serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
474+
)
475+
wr.catalog.add_csv_partitions(
476+
database=glue_database,
477+
table=glue_table,
478+
partitions_values=response["partitions_values"],
463479
serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
464480
serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
465481
)

tests/test_catalog.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,11 @@ def test_catalog_get_databases(glue_database):
157157
assert db["Description"] == "AWS Data Wrangler Test Arena - Glue Database"
158158

159159

160-
def test_catalog_versioning(path, glue_database, glue_table):
160+
def test_catalog_versioning(path, glue_database, glue_table, glue_table2):
161161
wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
162162
wr.s3.delete_objects(path=path)
163163

164-
# Version 0
164+
# Version 1 - Parquet
165165
df = pd.DataFrame({"c0": [1, 2]})
166166
wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table, mode="overwrite")[
167167
"paths"
@@ -172,7 +172,7 @@ def test_catalog_versioning(path, glue_database, glue_table):
172172
assert len(df.columns) == 1
173173
assert str(df.c0.dtype).startswith("Int")
174174

175-
# Version 1
175+
# Version 2 - Parquet
176176
df = pd.DataFrame({"c1": ["foo", "boo"]})
177177
wr.s3.to_parquet(
178178
df=df,
@@ -189,38 +189,56 @@ def test_catalog_versioning(path, glue_database, glue_table):
189189
assert len(df.columns) == 1
190190
assert str(df.c1.dtype) == "string"
191191

192-
# Version 2
192+
# Version 1 - CSV
193193
df = pd.DataFrame({"c1": [1.0, 2.0]})
194194
wr.s3.to_csv(
195195
df=df,
196196
path=path,
197197
dataset=True,
198198
database=glue_database,
199-
table=glue_table,
199+
table=glue_table2,
200200
mode="overwrite",
201201
catalog_versioning=True,
202202
index=False,
203203
)
204-
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 3
205-
df = wr.athena.read_sql_table(table=glue_table, database=glue_database)
204+
assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 1
205+
df = wr.athena.read_sql_table(table=glue_table2, database=glue_database)
206206
assert len(df.index) == 2
207207
assert len(df.columns) == 1
208208
assert str(df.c1.dtype).startswith("float")
209209

210-
# Version 3 (removing version 2)
210+
# Version 1 - CSV (No evolution)
211211
df = pd.DataFrame({"c1": [True, False]})
212212
wr.s3.to_csv(
213213
df=df,
214214
path=path,
215215
dataset=True,
216216
database=glue_database,
217-
table=glue_table,
217+
table=glue_table2,
218218
mode="overwrite",
219219
catalog_versioning=False,
220220
index=False,
221221
)
222-
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 3
223-
df = wr.athena.read_sql_table(table=glue_table, database=glue_database)
222+
assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 1
223+
df = wr.athena.read_sql_table(table=glue_table2, database=glue_database)
224+
assert len(df.index) == 2
225+
assert len(df.columns) == 1
226+
assert str(df.c1.dtype).startswith("boolean")
227+
228+
# Version 2 - CSV
229+
df = pd.DataFrame({"c1": [True, False]})
230+
wr.s3.to_csv(
231+
df=df,
232+
path=path,
233+
dataset=True,
234+
database=glue_database,
235+
table=glue_table2,
236+
mode="overwrite",
237+
catalog_versioning=True,
238+
index=False,
239+
)
240+
assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 2
241+
df = wr.athena.read_sql_table(table=glue_table2, database=glue_database)
224242
assert len(df.index) == 2
225243
assert len(df.columns) == 1
226244
assert str(df.c1.dtype).startswith("boolean")

0 commit comments

Comments
 (0)