Skip to content

Commit f86f3b1

Browse files
authored
(enhancement): Reduce LOC in S3 write methods create_table (#1626)
* (enhancement): Reduce LOC in S3 write methods create_table * Minor - Missing catalog table input
1 parent 6d8ae70 commit f86f3b1

File tree

2 files changed

+106
-182
lines changed

2 files changed

+106
-182
lines changed

awswrangler/s3/_write_parquet.py

Lines changed: 30 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ def _to_parquet(
181181
path_root: Optional[str] = None,
182182
filename_prefix: Optional[str] = uuid.uuid4().hex,
183183
max_rows_by_file: Optional[int] = 0,
184-
# bucketing: bool = False,
185184
) -> List[str]:
186185
file_path = _get_file_path(
187186
path_root=path_root, path=path, filename_prefix=filename_prefix, compression_ext=compression_ext
@@ -679,34 +678,36 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals,too-many-b
679678
if schema_evolution is False:
680679
_utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
681680

681+
create_table_args: Dict[str, Any] = {
682+
"database": database,
683+
"table": table,
684+
"path": path,
685+
"columns_types": columns_types,
686+
"table_type": table_type,
687+
"partitions_types": partitions_types,
688+
"bucketing_info": bucketing_info,
689+
"compression": compression,
690+
"description": description,
691+
"parameters": parameters,
692+
"columns_comments": columns_comments,
693+
"boto3_session": session,
694+
"mode": mode,
695+
"transaction_id": transaction_id,
696+
"catalog_versioning": catalog_versioning,
697+
"projection_enabled": projection_enabled,
698+
"projection_types": projection_types,
699+
"projection_ranges": projection_ranges,
700+
"projection_values": projection_values,
701+
"projection_intervals": projection_intervals,
702+
"projection_digits": projection_digits,
703+
"projection_storage_location_template": None,
704+
"catalog_id": catalog_id,
705+
"catalog_table_input": catalog_table_input,
706+
}
707+
682708
if (catalog_table_input is None) and (table_type == "GOVERNED"):
683-
catalog._create_parquet_table( # pylint: disable=protected-access
684-
database=database,
685-
table=table,
686-
path=path, # type: ignore
687-
columns_types=columns_types,
688-
table_type=table_type,
689-
partitions_types=partitions_types,
690-
bucketing_info=bucketing_info,
691-
compression=compression,
692-
description=description,
693-
parameters=parameters,
694-
columns_comments=columns_comments,
695-
boto3_session=session,
696-
mode=mode,
697-
transaction_id=transaction_id,
698-
catalog_versioning=catalog_versioning,
699-
projection_enabled=projection_enabled,
700-
projection_types=projection_types,
701-
projection_ranges=projection_ranges,
702-
projection_values=projection_values,
703-
projection_intervals=projection_intervals,
704-
projection_digits=projection_digits,
705-
projection_storage_location_template=None,
706-
catalog_id=catalog_id,
707-
catalog_table_input=catalog_table_input,
708-
)
709-
catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access
709+
catalog._create_parquet_table(**create_table_args) # pylint: disable=protected-access
710+
create_table_args["catalog_table_input"] = catalog._get_table_input( # pylint: disable=protected-access
710711
database=database,
711712
table=table,
712713
boto3_session=session,
@@ -743,32 +744,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals,too-many-b
743744
)
744745
if (database is not None) and (table is not None):
745746
try:
746-
catalog._create_parquet_table( # pylint: disable=protected-access
747-
database=database,
748-
table=table,
749-
path=path, # type: ignore
750-
columns_types=columns_types,
751-
table_type=table_type,
752-
partitions_types=partitions_types,
753-
bucketing_info=bucketing_info,
754-
compression=compression,
755-
description=description,
756-
parameters=parameters,
757-
columns_comments=columns_comments,
758-
boto3_session=session,
759-
mode=mode,
760-
transaction_id=transaction_id,
761-
catalog_versioning=catalog_versioning,
762-
projection_enabled=projection_enabled,
763-
projection_types=projection_types,
764-
projection_ranges=projection_ranges,
765-
projection_values=projection_values,
766-
projection_intervals=projection_intervals,
767-
projection_digits=projection_digits,
768-
projection_storage_location_template=None,
769-
catalog_id=catalog_id,
770-
catalog_table_input=catalog_table_input,
771-
)
747+
catalog._create_parquet_table(**create_table_args) # pylint: disable=protected-access
772748
if partitions_values and (regular_partitions is True) and (table_type != "GOVERNED"):
773749
_logger.debug("partitions_values:\n%s", partitions_values)
774750
catalog.add_parquet_partitions(

awswrangler/s3/_write_text.py

Lines changed: 76 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -537,45 +537,48 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
537537
if schema_evolution is False:
538538
_utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
539539

540+
create_table_args: Dict[str, Any] = {
541+
"database": database,
542+
"table": table,
543+
"path": path,
544+
"columns_types": columns_types,
545+
"table_type": table_type,
546+
"partitions_types": partitions_types,
547+
"bucketing_info": bucketing_info,
548+
"description": description,
549+
"parameters": parameters,
550+
"columns_comments": columns_comments,
551+
"boto3_session": session,
552+
"mode": mode,
553+
"transaction_id": transaction_id,
554+
"schema_evolution": schema_evolution,
555+
"catalog_versioning": catalog_versioning,
556+
"sep": sep,
557+
"projection_enabled": projection_enabled,
558+
"projection_types": projection_types,
559+
"projection_ranges": projection_ranges,
560+
"projection_values": projection_values,
561+
"projection_intervals": projection_intervals,
562+
"projection_digits": projection_digits,
563+
"projection_storage_location_template": None,
564+
"catalog_table_input": catalog_table_input,
565+
"catalog_id": catalog_id,
566+
"compression": pandas_kwargs.get("compression"),
567+
"skip_header_line_count": True if header else None,
568+
"serde_library": None,
569+
"serde_parameters": None,
570+
}
571+
540572
if (catalog_table_input is None) and (table_type == "GOVERNED"):
541-
catalog._create_csv_table( # pylint: disable=protected-access
542-
database=database,
543-
table=table,
544-
path=path,
545-
columns_types=columns_types,
546-
table_type=table_type,
547-
partitions_types=partitions_types,
548-
bucketing_info=bucketing_info,
549-
description=description,
550-
parameters=parameters,
551-
columns_comments=columns_comments,
552-
boto3_session=session,
553-
mode=mode,
554-
transaction_id=transaction_id,
555-
schema_evolution=schema_evolution,
556-
catalog_versioning=catalog_versioning,
557-
sep=sep,
558-
projection_enabled=projection_enabled,
559-
projection_types=projection_types,
560-
projection_ranges=projection_ranges,
561-
projection_values=projection_values,
562-
projection_intervals=projection_intervals,
563-
projection_digits=projection_digits,
564-
projection_storage_location_template=None,
565-
catalog_table_input=catalog_table_input,
566-
catalog_id=catalog_id,
567-
compression=pandas_kwargs.get("compression"),
568-
skip_header_line_count=None,
569-
serde_library=None,
570-
serde_parameters=None,
571-
)
573+
catalog._create_csv_table(**create_table_args) # pylint: disable=protected-access
572574
catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access
573575
database=database,
574576
table=table,
575577
boto3_session=session,
576578
transaction_id=transaction_id,
577579
catalog_id=catalog_id,
578580
)
581+
create_table_args["catalog_table_input"] = catalog_table_input
579582

580583
paths, partitions_values = _to_dataset(
581584
func=_to_text,
@@ -610,39 +613,9 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
610613
serde_info: Dict[str, Any] = {}
611614
if catalog_table_input:
612615
serde_info = catalog_table_input["StorageDescriptor"]["SerdeInfo"]
613-
serde_library: Optional[str] = serde_info.get("SerializationLibrary", None)
614-
serde_parameters: Optional[Dict[str, str]] = serde_info.get("Parameters", None)
615-
catalog._create_csv_table( # pylint: disable=protected-access
616-
database=database,
617-
table=table,
618-
path=path,
619-
columns_types=columns_types,
620-
table_type=table_type,
621-
partitions_types=partitions_types,
622-
bucketing_info=bucketing_info,
623-
description=description,
624-
parameters=parameters,
625-
columns_comments=columns_comments,
626-
boto3_session=session,
627-
mode=mode,
628-
transaction_id=transaction_id,
629-
catalog_versioning=catalog_versioning,
630-
schema_evolution=schema_evolution,
631-
sep=sep,
632-
projection_enabled=projection_enabled,
633-
projection_types=projection_types,
634-
projection_ranges=projection_ranges,
635-
projection_values=projection_values,
636-
projection_intervals=projection_intervals,
637-
projection_digits=projection_digits,
638-
projection_storage_location_template=None,
639-
catalog_table_input=catalog_table_input,
640-
catalog_id=catalog_id,
641-
compression=pandas_kwargs.get("compression"),
642-
skip_header_line_count=True if header else None,
643-
serde_library=serde_library,
644-
serde_parameters=serde_parameters,
645-
)
616+
create_table_args["serde_library"] = serde_info.get("SerializationLibrary", None)
617+
create_table_args["serde_parameters"] = serde_info.get("Parameters", None)
618+
catalog._create_csv_table(**create_table_args) # pylint: disable=protected-access
646619
if partitions_values and (regular_partitions is True) and (table_type != "GOVERNED"):
647620
_logger.debug("partitions_values:\n%s", partitions_values)
648621
catalog.add_csv_partitions(
@@ -652,8 +625,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
652625
bucketing_info=bucketing_info,
653626
boto3_session=session,
654627
sep=sep,
655-
serde_library=serde_library,
656-
serde_parameters=serde_parameters,
628+
serde_library=create_table_args["serde_library"],
629+
serde_parameters=create_table_args["serde_parameters"],
657630
catalog_id=catalog_id,
658631
columns_types=columns_types,
659632
compression=pandas_kwargs.get("compression"),
@@ -969,43 +942,46 @@ def to_json( # pylint: disable=too-many-arguments,too-many-locals,too-many-stat
969942
if schema_evolution is False:
970943
_utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
971944

945+
create_table_args: Dict[str, Any] = {
946+
"database": database,
947+
"table": table,
948+
"path": path,
949+
"columns_types": columns_types,
950+
"table_type": table_type,
951+
"partitions_types": partitions_types,
952+
"bucketing_info": bucketing_info,
953+
"description": description,
954+
"parameters": parameters,
955+
"columns_comments": columns_comments,
956+
"boto3_session": session,
957+
"mode": mode,
958+
"transaction_id": transaction_id,
959+
"catalog_versioning": catalog_versioning,
960+
"schema_evolution": schema_evolution,
961+
"projection_enabled": projection_enabled,
962+
"projection_types": projection_types,
963+
"projection_ranges": projection_ranges,
964+
"projection_values": projection_values,
965+
"projection_intervals": projection_intervals,
966+
"projection_digits": projection_digits,
967+
"projection_storage_location_template": None,
968+
"catalog_table_input": catalog_table_input,
969+
"catalog_id": catalog_id,
970+
"compression": compression,
971+
"serde_library": None,
972+
"serde_parameters": None,
973+
}
974+
972975
if (catalog_table_input is None) and (table_type == "GOVERNED"):
973-
catalog._create_json_table( # pylint: disable=protected-access
974-
database=database,
975-
table=table,
976-
path=path, # type: ignore
977-
columns_types=columns_types,
978-
table_type=table_type,
979-
partitions_types=partitions_types,
980-
bucketing_info=bucketing_info,
981-
description=description,
982-
parameters=parameters,
983-
columns_comments=columns_comments,
984-
boto3_session=session,
985-
mode=mode,
986-
transaction_id=transaction_id,
987-
catalog_versioning=catalog_versioning,
988-
schema_evolution=schema_evolution,
989-
projection_enabled=projection_enabled,
990-
projection_types=projection_types,
991-
projection_ranges=projection_ranges,
992-
projection_values=projection_values,
993-
projection_intervals=projection_intervals,
994-
projection_digits=projection_digits,
995-
projection_storage_location_template=None,
996-
catalog_table_input=catalog_table_input,
997-
catalog_id=catalog_id,
998-
compression=compression,
999-
serde_library=None,
1000-
serde_parameters=None,
1001-
)
976+
catalog._create_json_table(**create_table_args) # pylint: disable=protected-access
1002977
catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access
1003978
database=database,
1004979
table=table,
1005980
boto3_session=session,
1006981
transaction_id=transaction_id,
1007982
catalog_id=catalog_id,
1008983
)
984+
create_table_args["catalog_table_input"] = catalog_table_input
1009985

1010986
paths, partitions_values = _to_dataset(
1011987
func=_to_text,
@@ -1035,37 +1011,9 @@ def to_json( # pylint: disable=too-many-arguments,too-many-locals,too-many-stat
10351011
serde_info: Dict[str, Any] = {}
10361012
if catalog_table_input:
10371013
serde_info = catalog_table_input["StorageDescriptor"]["SerdeInfo"]
1038-
serde_library: Optional[str] = serde_info.get("SerializationLibrary", None)
1039-
serde_parameters: Optional[Dict[str, str]] = serde_info.get("Parameters", None)
1040-
catalog._create_json_table( # pylint: disable=protected-access
1041-
database=database,
1042-
table=table,
1043-
path=path, # type: ignore
1044-
columns_types=columns_types,
1045-
table_type=table_type,
1046-
partitions_types=partitions_types,
1047-
bucketing_info=bucketing_info,
1048-
description=description,
1049-
parameters=parameters,
1050-
columns_comments=columns_comments,
1051-
boto3_session=session,
1052-
mode=mode,
1053-
transaction_id=transaction_id,
1054-
catalog_versioning=catalog_versioning,
1055-
schema_evolution=schema_evolution,
1056-
projection_enabled=projection_enabled,
1057-
projection_types=projection_types,
1058-
projection_ranges=projection_ranges,
1059-
projection_values=projection_values,
1060-
projection_intervals=projection_intervals,
1061-
projection_digits=projection_digits,
1062-
projection_storage_location_template=None,
1063-
catalog_table_input=catalog_table_input,
1064-
catalog_id=catalog_id,
1065-
compression=compression,
1066-
serde_library=serde_library,
1067-
serde_parameters=serde_parameters,
1068-
)
1014+
create_table_args["serde_library"] = serde_info.get("SerializationLibrary", None)
1015+
create_table_args["serde_parameters"] = serde_info.get("Parameters", None)
1016+
catalog._create_json_table(**create_table_args) # pylint: disable=protected-access
10691017
if partitions_values and (regular_partitions is True) and (table_type != "GOVERNED"):
10701018
_logger.debug("partitions_values:\n%s", partitions_values)
10711019
catalog.add_json_partitions(
@@ -1074,8 +1022,8 @@ def to_json( # pylint: disable=too-many-arguments,too-many-locals,too-many-stat
10741022
partitions_values=partitions_values,
10751023
bucketing_info=bucketing_info,
10761024
boto3_session=session,
1077-
serde_library=serde_library,
1078-
serde_parameters=serde_parameters,
1025+
serde_library=create_table_args["serde_library"],
1026+
serde_parameters=create_table_args["serde_parameters"],
10791027
catalog_id=catalog_id,
10801028
columns_types=columns_types,
10811029
compression=compression,

0 commit comments

Comments
 (0)