Commit bd5526e

Major - Make path argument in wr.s3.to functions optional. Throw exception if existing table path is different to input (#598)
1 parent: 4ac86d5
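
In practice, the change works as follows: `path` is still required when first creating a table, but later writes can omit it and reuse the location recorded in the Glue catalog. A minimal sketch of the new calling convention (the bucket, database, and table names below are illustrative, not from this commit):

import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"c0": [1, 2, 3]})

# First write: the Glue table does not exist yet, so `path` must be given.
wr.s3.to_parquet(
    df=df, path="s3://my-bucket/my-prefix", dataset=True,
    database="my_db", table="my_table", mode="overwrite",
)

# Later writes: `path` may be omitted; it is resolved from the catalog entry.
wr.s3.to_parquet(df=df, dataset=True, database="my_db", table="my_table", mode="append")

# A `path` that conflicts with the table's stored location raises
# wr.exceptions.InvalidArgumentValue.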

6 files changed: +64 additions, −18 deletions

awswrangler/s3/_write.py

Lines changed: 7 additions & 1 deletion
@@ -47,7 +47,7 @@ def _validate_args(
     table: Optional[str],
     database: Optional[str],
     dataset: bool,
-    path: str,
+    path: Optional[str],
     partition_cols: Optional[List[str]],
     bucketing_info: Optional[Tuple[List[str], int]],
     mode: Optional[str],
@@ -58,6 +58,8 @@ def _validate_args(
     if df.empty is True:
         raise exceptions.EmptyDataFrame()
     if dataset is False:
+        if path is None:
+            raise exceptions.InvalidArgumentValue("If dataset is False, the `path` argument must be passed.")
         if path.endswith("/"):
             raise exceptions.InvalidArgumentValue(
                 "If <dataset=False>, the argument <path> should be a file path, not a directory."
@@ -79,6 +81,10 @@ def _validate_args(
             "Arguments database and table must be passed together. If you want to store your dataset metadata in "
             "the Glue Catalog, please ensure you are passing both."
         )
+    elif all(x is None for x in [path, database, table]):
+        raise exceptions.InvalidArgumentCombination(
+            "You must specify a `path` if dataset is True and database/table are not enabled."
+        )
     elif bucketing_info and bucketing_info[1] <= 0:
         raise exceptions.InvalidArgumentValue(
             "Please pass a value greater than 1 for the number of buckets for bucketing."

awswrangler/s3/_write_parquet.py

Lines changed: 18 additions & 4 deletions
@@ -198,7 +198,7 @@ def _to_parquet(
 @apply_configs
 def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
     df: pd.DataFrame,
-    path: str,
+    path: Optional[str] = None,
     index: bool = False,
     compression: Optional[str] = "snappy",
     max_rows_by_file: Optional[int] = None,
@@ -252,8 +252,9 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
     ----------
     df: pandas.DataFrame
         Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
-    path : str
+    path : str, optional
         S3 path (for file e.g. ``s3://bucket/prefix/filename.parquet``) (for dataset e.g. ``s3://bucket/prefix``).
+        Required if dataset=False or when dataset=True and creating a new dataset
     index : bool
         True to store the DataFrame index in file, otherwise False to ignore it.
     compression: str, optional
@@ -511,6 +512,19 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
         catalog_table_input = catalog._get_table_input(  # pylint: disable=protected-access
             database=database, table=table, boto3_session=session, catalog_id=catalog_id
         )
+        catalog_path = catalog_table_input["StorageDescriptor"]["Location"] if catalog_table_input else None
+        if path is None:
+            if catalog_path:
+                path = catalog_path
+            else:
+                raise exceptions.InvalidArgumentValue(
+                    "Glue table does not exist in the catalog. Please pass the `path` argument to create it."
+                )
+        elif path and catalog_path:
+            if path.rstrip("/") != catalog_path.rstrip("/"):
+                raise exceptions.InvalidArgumentValue(
+                    f"The specified path: {path}, does not match the existing Glue catalog table path: {catalog_path}"
+                )
         df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode)
         schema: pa.Schema = _data_types.pyarrow_schema_from_pandas(
             df=df, index=index, ignore_cols=partition_cols, dtype=dtype
@@ -545,7 +559,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
             func=_to_parquet,
             concurrent_partitioning=concurrent_partitioning,
             df=df,
-            path_root=path,
+            path_root=path,  # type: ignore
             index=index,
             compression=compression,
             compression_ext=compression_ext,
@@ -565,7 +579,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
         catalog._create_parquet_table(  # pylint: disable=protected-access
             database=database,
             table=table,
-            path=path,
+            path=path,  # type: ignore
             columns_types=columns_types,
            partitions_types=partitions_types,
             bucketing_info=bucketing_info,
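
With the hunk above, omitting `path` makes `to_parquet` look up the table's `StorageDescriptor.Location` in the Glue catalog, and a supplied `path` is compared against that location with trailing slashes ignored. A short sketch, assuming a Glue table `my_db.my_table` already registered at `s3://my-bucket/my-prefix` (illustrative names):

import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"c0": [1, 2]})

# `path` omitted: resolved from the catalog, writes to s3://my-bucket/my-prefix.
wr.s3.to_parquet(df=df, dataset=True, database="my_db", table="my_table", mode="append")

# `path` differs from the stored location: raises wr.exceptions.InvalidArgumentValue.
wr.s3.to_parquet(
    df=df, path="s3://other-bucket/other-prefix", dataset=True,
    database="my_db", table="my_table", mode="append",
)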

awswrangler/s3/_write_text.py

Lines changed: 22 additions & 7 deletions
@@ -72,9 +72,9 @@ def _to_text(
 
 
 @apply_configs
-def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
+def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches
     df: pd.DataFrame,
-    path: str,
+    path: Optional[str] = None,
     sep: str = ",",
     index: bool = True,
     columns: Optional[List[str]] = None,
@@ -137,8 +137,9 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
     ----------
     df: pandas.DataFrame
         Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
-    path : str
-        Amazon S3 path (e.g. s3://bucket/filename.csv).
+    path : str, optional
+        Amazon S3 path (e.g. s3://bucket/prefix/filename.csv) (for dataset e.g. ``s3://bucket/prefix``).
+        Required if dataset=False or when creating a new dataset
     sep : str
         String of length 1. Field delimiter for the output file.
     index : bool
@@ -414,13 +415,27 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
         catalog_table_input = catalog._get_table_input(  # pylint: disable=protected-access
             database=database, table=table, boto3_session=session, catalog_id=catalog_id
         )
+        catalog_path = catalog_table_input["StorageDescriptor"]["Location"] if catalog_table_input else None
+        if path is None:
+            if catalog_path:
+                path = catalog_path
+            else:
+                raise exceptions.InvalidArgumentValue(
+                    "Glue table does not exist in the catalog. Please pass the `path` argument to create it."
+                )
+        elif path and catalog_path:
+            if path.rstrip("/") != catalog_path.rstrip("/"):
+                raise exceptions.InvalidArgumentValue(
+                    f"The specified path: {path}, does not match the existing Glue catalog table path: {catalog_path}"
+                )
         if pandas_kwargs.get("compression") not in ("gzip", "bz2", None):
             raise exceptions.InvalidArgumentCombination(
                 "If database and table are given, you must use one of these compressions: gzip, bz2 or None."
             )
 
     df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode)
 
+    paths: List[str] = []
     if dataset is False:
         pandas_kwargs["sep"] = sep
         pandas_kwargs["index"] = index
@@ -434,7 +449,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
             s3_additional_kwargs=s3_additional_kwargs,
             **pandas_kwargs,
         )
-        paths = [path]
+        paths = [path]  # type: ignore
     else:
         if database and table:
             quoting: Optional[int] = csv.QUOTE_NONE
@@ -461,7 +476,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
             func=_to_text,
             concurrent_partitioning=concurrent_partitioning,
             df=df,
-            path_root=path,
+            path_root=path,  # type: ignore
             index=index,
             sep=sep,
             compression=compression,
@@ -486,7 +501,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
         catalog._create_csv_table(  # pylint: disable=protected-access
             database=database,
             table=table,
-            path=path,
+            path=path,  # type: ignore
             columns_types=columns_types,
             partitions_types=partitions_types,
             bucketing_info=bucketing_info,
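
`to_csv` gains the same catalog-path resolution and mismatch check as `to_parquet`. A brief sketch under the same assumptions (an existing Glue table `my_db.my_table`; names are illustrative):

import awswrangler as wr
import pandas as pd

# `path` omitted: resolved from the Glue catalog entry for my_db.my_table.
wr.s3.to_csv(
    df=pd.DataFrame({"c0": [1, 2]}),
    dataset=True, database="my_db", table="my_table", mode="append",
)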

tests/test__routines.py

Lines changed: 0 additions & 4 deletions
@@ -44,7 +44,6 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
     df = pd.DataFrame({"c1": [None, 1, None]}, dtype="Int16")
     wr.s3.to_parquet(
         df=df,
-        path=path,
         dataset=True,
         mode="overwrite",
         database=glue_database,
@@ -101,7 +100,6 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
     df = pd.DataFrame({"c2": ["a", None, "b"], "c1": [None, None, None]})
     wr.s3.to_parquet(
         df=df,
-        path=path,
         dataset=True,
         mode="append",
         database=glue_database,
@@ -162,7 +160,6 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
     df = pd.DataFrame({"c0": ["foo", None], "c1": [0, 1]})
     wr.s3.to_parquet(
         df=df,
-        path=path,
         dataset=True,
         mode="overwrite",
         database=glue_database,
@@ -223,7 +220,6 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part
     df = pd.DataFrame({"c0": [1, 2], "c1": ["1", "3"], "c2": [True, False]})
     wr.s3.to_parquet(
         df=df,
-        path=path,
         dataset=True,
         mode="overwrite_partitions",
         database=glue_database,

tests/test_athena_csv.py

Lines changed: 0 additions & 2 deletions
@@ -49,7 +49,6 @@ def test_to_csv_modes(glue_database, glue_table, path, use_threads, concurrent_p
     df = pd.DataFrame({"c1": [0, 1, 2]}, dtype="Int16")
     wr.s3.to_csv(
         df=df,
-        path=path,
         dataset=True,
         mode="overwrite",
         database=glue_database,
@@ -106,7 +105,6 @@ def test_to_csv_modes(glue_database, glue_table, path, use_threads, concurrent_p
     df = pd.DataFrame({"c0": ["foo", "boo"], "c1": [0, 1]})
     wr.s3.to_csv(
         df=df,
-        path=path,
         dataset=True,
         mode="overwrite",
         database=glue_database,

tests/test_s3.py

Lines changed: 17 additions & 0 deletions
@@ -103,6 +103,23 @@ def mock_make_api_call(self, operation_name, kwarg):
     wr.s3.delete_objects(path=[path])
 
 
+def test_missing_or_wrong_path(path, glue_database, glue_table):
+    # Missing path
+    df = pd.DataFrame({"FooBoo": [1, 2, 3]})
+    with pytest.raises(wr.exceptions.InvalidArgumentValue):
+        wr.s3.to_parquet(df=df)
+    with pytest.raises(wr.exceptions.InvalidArgumentCombination):
+        wr.s3.to_parquet(df=df, dataset=True)
+    with pytest.raises(wr.exceptions.InvalidArgumentValue):
+        wr.s3.to_parquet(df=df, dataset=True, database=glue_database, table=glue_table)
+
+    # Wrong path
+    wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)
+    wrong_path = "s3://bucket/prefix"
+    with pytest.raises(wr.exceptions.InvalidArgumentValue):
+        wr.s3.to_parquet(df=df, path=wrong_path, dataset=True, database=glue_database, table=glue_table)
+
+
 def test_s3_empty_dfs():
     df = pd.DataFrame()
     with pytest.raises(wr.exceptions.EmptyDataFrame):
