
Commit ca73c19

Add serde parameters to csv table creation (#673)
Co-authored-by: jaidisido <[email protected]>
1 parent 7173322 commit ca73c19

4 files changed, +56 -17 lines

awswrangler/catalog/_create.py

Lines changed: 16 additions & 1 deletion
@@ -296,6 +296,8 @@ def _create_csv_table(
     catalog_versioning: bool,
     sep: str,
     skip_header_line_count: Optional[int],
+    serde_library: Optional[str],
+    serde_parameters: Optional[Dict[str, str]],
     boto3_session: Optional[boto3.Session],
     projection_enabled: bool,
     projection_types: Optional[Dict[str, str]],
@@ -329,6 +331,8 @@ def _create_csv_table(
         compression=compression,
         sep=sep,
         skip_header_line_count=skip_header_line_count,
+        serde_library=serde_library,
+        serde_parameters=serde_parameters,
     )
     table_exist: bool = catalog_table_input is not None
     _logger.debug("table_exist: %s", table_exist)
@@ -670,6 +674,8 @@ def create_csv_table(
     catalog_versioning: bool = False,
     sep: str = ",",
     skip_header_line_count: Optional[int] = None,
+    serde_library: Optional[str] = None,
+    serde_parameters: Optional[Dict[str, str]] = None,
     boto3_session: Optional[boto3.Session] = None,
     projection_enabled: bool = False,
     projection_types: Optional[Dict[str, str]] = None,
@@ -679,7 +685,7 @@ def create_csv_table(
     projection_digits: Optional[Dict[str, str]] = None,
     catalog_id: Optional[str] = None,
 ) -> None:
-    """Create a CSV Table (Metadata Only) in the AWS Glue Catalog.
+    r"""Create a CSV Table (Metadata Only) in the AWS Glue Catalog.
 
     'https://docs.aws.amazon.com/athena/latest/ug/data-types.html'
 
@@ -715,6 +721,13 @@ def create_csv_table(
         String of length 1. Field delimiter for the output file.
     skip_header_line_count : Optional[int]
         Number of Lines to skip regarding to the header.
+    serde_library : Optional[str]
+        Specifies the SerDe Serialization library which will be used. You need to provide the Class library name
+        as a string.
+        If no library is provided the default is `org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe`.
+    serde_parameters : Optional[Dict[str, str]]
+        Dictionary of initialization parameters for the SerDe.
+        The default is `{"field.delim": sep, "escape.delim": "\\"}`.
     projection_enabled : bool
         Enable Partition Projection on Athena (https://docs.aws.amazon.com/athena/latest/ug/partition-projection.html)
     projection_types : Optional[Dict[str, str]]
@@ -793,4 +806,6 @@ def create_csv_table(
         catalog_table_input=catalog_table_input,
         sep=sep,
         skip_header_line_count=skip_header_line_count,
+        serde_library=serde_library,
+        serde_parameters=serde_parameters,
     )
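To see what the new signature enables end to end, here is a minimal usage sketch. The database, table, and S3 path below are placeholders, and the OpenCSVSerde class and parameters mirror the ones exercised in the new test further down:

import awswrangler as wr

# Register a Glue table over quoted CSV files, swapping the default
# LazySimpleSerDe for OpenCSVSerde. All names and paths are placeholders.
wr.catalog.create_csv_table(
    database="my_database",
    table="my_table",
    path="s3://my-bucket/my-prefix/",
    columns_types={"id": "bigint", "name": "string"},
    serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
    serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
)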

awswrangler/catalog/_definitions.py

Lines changed: 11 additions & 13 deletions
@@ -105,6 +105,8 @@ def _csv_table_definition(
     compression: Optional[str],
     sep: str,
     skip_header_line_count: Optional[int],
+    serde_library: Optional[str],
+    serde_parameters: Optional[Dict[str, str]],
 ) -> Dict[str, Any]:
     compressed: bool = compression is not None
     parameters: Dict[str, str] = {
@@ -116,7 +118,13 @@ def _csv_table_definition(
         "areColumnsQuoted": "false",
     }
     if skip_header_line_count is not None:
-        parameters["skip.header.line.count"] = "1"
+        parameters["skip.header.line.count"] = str(skip_header_line_count)
+    serde_info = {
+        "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+        if serde_library is None
+        else serde_library,
+        "Parameters": {"field.delim": sep, "escape.delim": "\\"} if serde_parameters is None else serde_parameters,
+    }
     return {
         "Name": table,
         "PartitionKeys": [{"Name": cname, "Type": dtype} for cname, dtype in partitions_types.items()],
@@ -129,21 +137,11 @@ def _csv_table_definition(
             "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
             "Compressed": compressed,
             "NumberOfBuckets": -1 if bucketing_info is None else bucketing_info[1],
-            "SerdeInfo": {
-                "Parameters": {"field.delim": sep, "escape.delim": "\\"},
-                "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
-            },
+            "SerdeInfo": serde_info,
             "BucketColumns": [] if bucketing_info is None else bucketing_info[0],
             "StoredAsSubDirectories": False,
             "SortColumns": [],
-            "Parameters": {
-                "classification": "csv",
-                "compressionType": str(compression).lower(),
-                "typeOfData": "file",
-                "delimiter": sep,
-                "columnsOrdered": "true",
-                "areColumnsQuoted": "false",
-            },
+            "Parameters": parameters,
         },
     }
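For reference, the new `serde_info` branch resolves in two ways: with both arguments left as `None` it reproduces the literal the old code hard-coded, so existing callers see no behavioral change; with explicit arguments, both fields are taken verbatim from the caller. Note that custom parameters replace the default dict wholesale rather than merging with it. A sketch of the two outcomes, assuming `sep=","`:

# Defaults (serde_library=None, serde_parameters=None): identical to the old hard-coded block.
{
    "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
    "Parameters": {"field.delim": ",", "escape.delim": "\\"},
}

# Caller-supplied values pass through unchanged, e.g. for OpenCSVSerde:
{
    "SerializationLibrary": "org.apache.hadoop.hive.serde2.OpenCSVSerde",
    "Parameters": {"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
}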

awswrangler/s3/_write_text.py

Lines changed: 2 additions & 0 deletions
@@ -525,6 +525,8 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
         catalog_id=catalog_id,
         compression=pandas_kwargs.get("compression"),
         skip_header_line_count=None,
+        serde_library=None,
+        serde_parameters=None,
     )
     if partitions_values and (regular_partitions is True):
         _logger.debug("partitions_values:\n%s", partitions_values)
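Note that `to_csv` pins both new arguments to `None`, so tables registered through dataset writes keep the `LazySimpleSerDe` defaults; a custom SerDe is only reachable by calling `wr.catalog.create_csv_table` directly, as the new test below does.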

tests/test_athena_csv.py

Lines changed: 27 additions & 3 deletions
@@ -1,3 +1,4 @@
+import csv
 import logging
 from sys import version_info
 
@@ -337,7 +338,8 @@ def test_athena_csv_types(path, glue_database, glue_table):
 
 @pytest.mark.parametrize("use_threads", [True, False])
 @pytest.mark.parametrize("ctas_approach", [True, False])
-def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach):
+@pytest.mark.parametrize("line_count", [1, 2])
+def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach, line_count):
     df = pd.DataFrame({"c0": [1, 2], "c1": [3.3, 4.4], "c2": ["foo", "boo"]})
     df["c0"] = df["c0"].astype("Int64")
     df["c2"] = df["c2"].astype("string")
@@ -347,10 +349,10 @@ def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach
         table=glue_table,
         path=path,
         columns_types={"c0": "bigint", "c1": "double", "c2": "string"},
-        skip_header_line_count=1,
+        skip_header_line_count=line_count,
     )
     df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads, ctas_approach=ctas_approach)
-    assert df.equals(df2)
+    assert df.iloc[line_count - 1 :].reset_index(drop=True).equals(df2)
 
 
 @pytest.mark.parametrize("use_threads", [True, False])
@@ -444,3 +446,25 @@ def test_csv_compressed(path, glue_table, glue_database, use_threads, concurrent
     assert df2["id"].sum() == 6
     ensure_data_types_csv(df2)
     assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+@pytest.mark.parametrize("ctas_approach", [True, False])
+def test_opencsv_serde(path, glue_table, glue_database, use_threads, ctas_approach):
+    df = pd.DataFrame({"c0": ['"1"', '"2"', '"3"'], "c1": ['"4"', '"5"', '"6"'], "c2": ['"a"', '"b"', '"c"']})
+    wr.s3.to_csv(
+        df=df, path=f"{path}0.csv", sep=",", index=False, header=False, use_threads=use_threads, quoting=csv.QUOTE_NONE
+    )
+    wr.catalog.create_csv_table(
+        database=glue_database,
+        table=glue_table,
+        path=path,
+        columns_types={"c0": "string", "c1": "string", "c2": "string"},
+        serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
+        serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
+    )
+    df2 = wr.athena.read_sql_table(
+        table=glue_table, database=glue_database, use_threads=use_threads, ctas_approach=ctas_approach
+    )
+    df = df.applymap(lambda x: x.replace('"', "")).convert_dtypes()
+    assert df.equals(df2)
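The round trip in this test works because the frame is written with `csv.QUOTE_NONE`, so the double quotes land in the raw S3 object verbatim; OpenCSVSerde then interprets them as field quoting at read time. The expected frame therefore strips the quotes (and `convert_dtypes` aligns the dtypes) before the comparison.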
