
Commit ca831ef

Inverting file format and file compression extensions (key prefix)
1 parent: 120bb68

File tree

2 files changed: +29, -20 lines


awswrangler/pandas.py

Lines changed: 17 additions & 17 deletions
@@ -682,7 +682,7 @@ def to_csv(self,
         """
         if serde not in Pandas.VALID_CSV_SERDES:
             raise InvalidSerDe(f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})")
-        extra_args = {"sep": sep, "serde": serde, "escapechar": escapechar}
+        extra_args: Dict[str, Optional[str]] = {"sep": sep, "serde": serde, "escapechar": escapechar}
         return self.to_s3(dataframe=dataframe,
                           path=path,
                           file_format="csv",
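A note on the new annotation: extra_args maps option names to values that may legitimately be None, hence Dict[str, Optional[str]] rather than Dict[str, str]. A minimal illustration of the shape (the "OpenCSVSerDe" value is an assumption about Pandas.VALID_CSV_SERDES, not something taken from this diff):

    from typing import Dict, Optional

    # escapechar is commonly left unset, so the value type must allow None
    extra_args: Dict[str, Optional[str]] = {
        "sep": ",",
        "serde": "OpenCSVSerDe",  # assumed member of Pandas.VALID_CSV_SERDES
        "escapechar": None,
    }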
@@ -767,7 +767,7 @@ def to_s3(self,
               procs_cpu_bound=None,
               procs_io_bound=None,
               cast_columns=None,
-              extra_args=None,
+              extra_args: Optional[Dict[str, Optional[str]]] = None,
               inplace: bool = True,
               description: Optional[str] = None,
               parameters: Optional[Dict[str, str]] = None,
@@ -922,7 +922,7 @@ def _data_to_s3_dataset_writer(dataframe: pd.DataFrame,
                                    session_primitives: "SessionPrimitives",
                                    file_format: str,
                                    cast_columns=None,
-                                   extra_args=None,
+                                   extra_args: Optional[Dict[str, Optional[str]]] = None,
                                    isolated_dataframe: bool = False):
         objects_paths = []
         dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
@@ -980,7 +980,7 @@ def _data_to_s3_dataset_writer_remote(send_pipe,
                                           session_primitives: "SessionPrimitives",
                                           file_format,
                                           cast_columns=None,
-                                          extra_args=None):
+                                          extra_args: Optional[Dict[str, Optional[str]]] = None):
         send_pipe.send(
             Pandas._data_to_s3_dataset_writer(dataframe=dataframe,
                                               path=path,
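For context, _data_to_s3_dataset_writer_remote is the child-process entry point: it runs the writer and ships the resulting object paths back through a multiprocessing pipe. A minimal, self-contained sketch of that pattern (standard library only; the worker body is a stand-in, not the real writer):

    import multiprocessing as mp

    def worker(send_pipe, part):
        # stand-in for Pandas._data_to_s3_dataset_writer(...), which
        # returns the list of written object paths
        send_pipe.send([f"s3://bucket/prefix/{part}.parquet"])
        send_pipe.close()

    if __name__ == "__main__":
        receive_end, send_end = mp.Pipe(duplex=False)
        proc = mp.Process(target=worker, args=(send_end, "part-0"))
        proc.start()
        print(receive_end.recv())  # ['s3://bucket/prefix/part-0.parquet']
        proc.join()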
@@ -996,35 +996,35 @@ def _data_to_s3_dataset_writer_remote(send_pipe,
 
     @staticmethod
     def _data_to_s3_object_writer(dataframe: pd.DataFrame,
-                                  path: "str",
+                                  path: str,
                                   preserve_index: bool,
-                                  compression,
+                                  compression: str,
                                   session_primitives: "SessionPrimitives",
-                                  file_format,
-                                  cast_columns=None,
-                                  extra_args=None,
-                                  isolated_dataframe=False):
+                                  file_format: str,
+                                  cast_columns: Optional[List[str]] = None,
+                                  extra_args: Optional[Dict[str, Optional[str]]] = None,
+                                  isolated_dataframe=False) -> str:
         fs = get_fs(session_primitives=session_primitives)
         fs = pa.filesystem._ensure_filesystem(fs)
         mkdir_if_not_exists(fs, path)
 
         if compression is None:
-            compression_end = ""
+            compression_extension: str = ""
         elif compression == "snappy":
-            compression_end = ".snappy"
+            compression_extension = ".snappy"
         elif compression == "gzip":
-            compression_end = ".gz"
+            compression_extension = ".gz"
         else:
             raise InvalidCompression(compression)
 
-        guid = pa.compat.guid()
+        guid: str = pa.compat.guid()
         if file_format == "parquet":
-            outfile = f"{guid}.parquet{compression_end}"
+            outfile: str = f"{guid}{compression_extension}.parquet"
         elif file_format == "csv":
-            outfile = f"{guid}.csv{compression_end}"
+            outfile = f"{guid}{compression_extension}.csv"
         else:
             raise UnsupportedFileFormat(file_format)
-        object_path = "/".join([path, outfile])
+        object_path: str = "/".join([path, outfile])
         if file_format == "parquet":
             Pandas.write_parquet_dataframe(dataframe=dataframe,
                                            path=object_path,
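The behavioral core of the commit is this last hunk: the compression suffix now comes before the file-format suffix, so the format is always the final extension of the S3 key. A minimal sketch of the resulting naming (object_name is a hypothetical helper written for illustration, not part of awswrangler):

    from typing import Optional

    def object_name(guid: str, file_format: str, compression: Optional[str] = None) -> str:
        # mirrors the mapping in _data_to_s3_object_writer above
        extensions = {None: "", "snappy": ".snappy", "gzip": ".gz"}
        return f"{guid}{extensions[compression]}.{file_format}"

    # Before this commit the same inputs produced "abc123.parquet.snappy"
    # and "abc123.csv.gz".
    print(object_name("abc123", "parquet", "snappy"))  # abc123.snappy.parquet
    print(object_name("abc123", "csv", "gzip"))        # abc123.gz.csv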

testing/test_awswrangler/test_pandas.py

Lines changed: 12 additions & 3 deletions
@@ -1927,7 +1927,12 @@ def test_aurora_postgres_load_special(bucket, postgres_parameters):
         "value": ["foo", "boo", "bar", "abc"],
         "slashes": ["\\", "\"", "\\\\\\\\", "\"\"\"\""],
         "floats": [1.0, 2.0, 3.0, 4.0],
-        "decimals": [Decimal((0, (1, 9, 9), -2)), Decimal((0, (1, 9, 9), -2)), Decimal((0, (1, 9, 0), -2)), Decimal((0, (3, 1, 2), -2))]
+        "decimals": [
+            Decimal((0, (1, 9, 9), -2)),
+            Decimal((0, (1, 9, 9), -2)),
+            Decimal((0, (1, 9, 0), -2)),
+            Decimal((0, (3, 1, 2), -2))
+        ]
     })
 
     path = f"s3://{bucket}/test_aurora_postgres_special"
@@ -1977,8 +1982,12 @@ def test_aurora_mysql_load_special(bucket, mysql_parameters):
         "value": ["foo", "boo", "bar", "abc"],
         "slashes": ["\\", "\"", "\\\\\\\\", "\"\"\"\""],
         "floats": [1.0, 2.0, 3.0, 4.0],
-        "decimals": [Decimal((0, (1, 9, 9), -2)), Decimal((0, (1, 9, 9), -2)), Decimal((0, (1, 9, 0), -2)),
-                     Decimal((0, (3, 1, 2), -2))]
+        "decimals": [
+            Decimal((0, (1, 9, 9), -2)),
+            Decimal((0, (1, 9, 9), -2)),
+            Decimal((0, (1, 9, 0), -2)),
+            Decimal((0, (3, 1, 2), -2))
+        ]
     })
 
     path = f"s3://{bucket}/test_aurora_mysql_special"
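For reference, the tuple form of the Decimal constructor is (sign, digits, exponent), so the reformatted literals above denote exact values:

    from decimal import Decimal

    assert Decimal((0, (1, 9, 9), -2)) == Decimal("1.99")
    assert Decimal((0, (1, 9, 0), -2)) == Decimal("1.90")
    assert Decimal((0, (3, 1, 2), -2)) == Decimal("3.12")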
