Skip to content

Commit b3bac5c

Browse files
authored
Add storage location template for glue (#1023)
* Add storage location template for glue * Linting * Remove storage location template for writing * Trigger CI * Fix tests
1 parent bae604e commit b3bac5c

File tree

5 files changed

+65
-2
lines changed

5 files changed

+65
-2
lines changed

awswrangler/catalog/_create.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements
4242
projection_values: Optional[Dict[str, str]],
4343
projection_intervals: Optional[Dict[str, str]],
4444
projection_digits: Optional[Dict[str, str]],
45+
projection_storage_location_template: Optional[str],
4546
catalog_id: Optional[str],
4647
) -> None:
4748
# Description
@@ -71,7 +72,7 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements
7172
projection_digits = {sanitize_column_name(k): v for k, v in projection_digits.items()}
7273
for k, v in projection_types.items():
7374
dtype: Optional[str] = partitions_types.get(k)
74-
if dtype is None:
75+
if dtype is None and projection_storage_location_template is None:
7576
raise exceptions.InvalidArgumentCombination(
7677
f"Column {k} appears as projected column but not as partitioned column."
7778
)
@@ -95,6 +96,12 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements
9596
mode = _update_if_necessary(
9697
dic=table_input["Parameters"], key=f"projection.{k}.digits", value=str(v), mode=mode
9798
)
99+
mode = _update_if_necessary(
100+
table_input["Parameters"],
101+
key="storage.location.template",
102+
value=projection_storage_location_template,
103+
mode=mode,
104+
)
98105
else:
99106
table_input["Parameters"]["projection.enabled"] = "false"
100107

@@ -232,6 +239,7 @@ def _create_parquet_table(
232239
projection_values: Optional[Dict[str, str]],
233240
projection_intervals: Optional[Dict[str, str]],
234241
projection_digits: Optional[Dict[str, str]],
242+
projection_storage_location_template: Optional[str],
235243
boto3_session: Optional[boto3.Session],
236244
catalog_table_input: Optional[Dict[str, Any]],
237245
) -> None:
@@ -280,6 +288,7 @@ def _create_parquet_table(
280288
projection_values=projection_values,
281289
projection_intervals=projection_intervals,
282290
projection_digits=projection_digits,
291+
projection_storage_location_template=projection_storage_location_template,
283292
catalog_id=catalog_id,
284293
)
285294

@@ -309,6 +318,7 @@ def _create_csv_table( # pylint: disable=too-many-arguments
309318
projection_values: Optional[Dict[str, str]],
310319
projection_intervals: Optional[Dict[str, str]],
311320
projection_digits: Optional[Dict[str, str]],
321+
projection_storage_location_template: Optional[str],
312322
catalog_table_input: Optional[Dict[str, Any]],
313323
catalog_id: Optional[str],
314324
) -> None:
@@ -353,6 +363,7 @@ def _create_csv_table( # pylint: disable=too-many-arguments
353363
projection_values=projection_values,
354364
projection_intervals=projection_intervals,
355365
projection_digits=projection_digits,
366+
projection_storage_location_template=projection_storage_location_template,
356367
catalog_id=catalog_id,
357368
)
358369

@@ -380,6 +391,7 @@ def _create_json_table( # pylint: disable=too-many-arguments
380391
projection_values: Optional[Dict[str, str]],
381392
projection_intervals: Optional[Dict[str, str]],
382393
projection_digits: Optional[Dict[str, str]],
394+
projection_storage_location_template: Optional[str],
383395
catalog_table_input: Optional[Dict[str, Any]],
384396
catalog_id: Optional[str],
385397
) -> None:
@@ -422,6 +434,7 @@ def _create_json_table( # pylint: disable=too-many-arguments
422434
projection_values=projection_values,
423435
projection_intervals=projection_intervals,
424436
projection_digits=projection_digits,
437+
projection_storage_location_template=projection_storage_location_template,
425438
catalog_id=catalog_id,
426439
)
427440

@@ -613,6 +626,7 @@ def create_parquet_table(
613626
projection_values: Optional[Dict[str, str]] = None,
614627
projection_intervals: Optional[Dict[str, str]] = None,
615628
projection_digits: Optional[Dict[str, str]] = None,
629+
projection_storage_location_template: Optional[str] = None,
616630
boto3_session: Optional[boto3.Session] = None,
617631
) -> None:
618632
"""Create a Parquet Table (Metadata Only) in the AWS Glue Catalog.
@@ -673,6 +687,11 @@ def create_parquet_table(
673687
Dictionary of partitions names and Athena projections digits.
674688
https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
675689
(e.g. {'col_name': '1', 'col2_name': '2'})
690+
projection_storage_location_template : Optional[str]
691+
Value which allows Athena to properly map partition values if the S3 file locations do not follow
692+
a typical `.../column=value/...` pattern.
693+
https://docs.aws.amazon.com/athena/latest/ug/partition-projection-setting-up.html
694+
(e.g. s3://bucket/table_root/a=${a}/${b}/some_static_subdirectory/${c}/)
676695
boto3_session : boto3.Session(), optional
677696
Boto3 Session. The default boto3 session will be used if boto3_session receives None.
678697
@@ -721,13 +740,14 @@ def create_parquet_table(
721740
projection_values=projection_values,
722741
projection_intervals=projection_intervals,
723742
projection_digits=projection_digits,
743+
projection_storage_location_template=projection_storage_location_template,
724744
boto3_session=boto3_session,
725745
catalog_table_input=catalog_table_input,
726746
)
727747

728748

729749
@apply_configs
730-
def create_csv_table(
750+
def create_csv_table( # pylint: disable=too-many-arguments
731751
database: str,
732752
table: str,
733753
path: str,
@@ -752,6 +772,7 @@ def create_csv_table(
752772
projection_values: Optional[Dict[str, str]] = None,
753773
projection_intervals: Optional[Dict[str, str]] = None,
754774
projection_digits: Optional[Dict[str, str]] = None,
775+
projection_storage_location_template: Optional[str] = None,
755776
catalog_id: Optional[str] = None,
756777
) -> None:
757778
r"""Create a CSV Table (Metadata Only) in the AWS Glue Catalog.
@@ -825,6 +846,11 @@ def create_csv_table(
825846
Dictionary of partitions names and Athena projections digits.
826847
https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
827848
(e.g. {'col_name': '1', 'col2_name': '2'})
849+
projection_storage_location_template : Optional[str]
850+
Value which allows Athena to properly map partition values if the S3 file locations do not follow
851+
a typical `.../column=value/...` pattern.
852+
https://docs.aws.amazon.com/athena/latest/ug/partition-projection-setting-up.html
853+
(e.g. s3://bucket/table_root/a=${a}/${b}/some_static_subdirectory/${c}/)
828854
boto3_session : boto3.Session(), optional
829855
Boto3 Session. The default boto3 session will be used if boto3_session receives None.
830856
catalog_id : str, optional
@@ -877,6 +903,7 @@ def create_csv_table(
877903
projection_values=projection_values,
878904
projection_intervals=projection_intervals,
879905
projection_digits=projection_digits,
906+
projection_storage_location_template=projection_storage_location_template,
880907
boto3_session=boto3_session,
881908
catalog_table_input=catalog_table_input,
882909
sep=sep,
@@ -910,6 +937,7 @@ def create_json_table(
910937
projection_values: Optional[Dict[str, str]] = None,
911938
projection_intervals: Optional[Dict[str, str]] = None,
912939
projection_digits: Optional[Dict[str, str]] = None,
940+
projection_storage_location_template: Optional[str] = None,
913941
catalog_id: Optional[str] = None,
914942
) -> None:
915943
r"""Create a JSON Table (Metadata Only) in the AWS Glue Catalog.
@@ -979,6 +1007,11 @@ def create_json_table(
9791007
Dictionary of partitions names and Athena projections digits.
9801008
https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
9811009
(e.g. {'col_name': '1', 'col2_name': '2'})
1010+
projection_storage_location_template : Optional[str]
1011+
Value which allows Athena to properly map partition values if the S3 file locations do not follow
1012+
a typical `.../column=value/...` pattern.
1013+
https://docs.aws.amazon.com/athena/latest/ug/partition-projection-setting-up.html
1014+
(e.g. s3://bucket/table_root/a=${a}/${b}/some_static_subdirectory/${c}/)
9821015
boto3_session : boto3.Session(), optional
9831016
Boto3 Session. The default boto3 session will be used if boto3_session receives None.
9841017
catalog_id : str, optional
@@ -1030,6 +1063,7 @@ def create_json_table(
10301063
projection_values=projection_values,
10311064
projection_intervals=projection_intervals,
10321065
projection_digits=projection_digits,
1066+
projection_storage_location_template=projection_storage_location_template,
10331067
boto3_session=boto3_session,
10341068
catalog_table_input=catalog_table_input,
10351069
serde_library=serde_library,

awswrangler/s3/_write_parquet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals
606606
projection_values=projection_values,
607607
projection_intervals=projection_intervals,
608608
projection_digits=projection_digits,
609+
projection_storage_location_template=None,
609610
catalog_id=catalog_id,
610611
catalog_table_input=catalog_table_input,
611612
)

awswrangler/s3/_write_text.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
538538
projection_values=projection_values,
539539
projection_intervals=projection_intervals,
540540
projection_digits=projection_digits,
541+
projection_storage_location_template=None,
541542
catalog_table_input=catalog_table_input,
542543
catalog_id=catalog_id,
543544
compression=pandas_kwargs.get("compression"),
@@ -888,6 +889,7 @@ def to_json( # pylint: disable=too-many-arguments,too-many-locals,too-many-stat
888889
projection_values=projection_values,
889890
projection_intervals=projection_intervals,
890891
projection_digits=projection_digits,
892+
projection_storage_location_template=None,
891893
catalog_table_input=catalog_table_input,
892894
catalog_id=catalog_id,
893895
compression=pandas_kwargs.get("compression"),

tests/test_athena_projection.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,28 @@ def test_to_parquet_projection_injected(glue_database, glue_table, path):
9494
df2 = wr.athena.read_sql_query(f"SELECT * FROM {glue_table} WHERE c1='foo' AND c2='0'", glue_database)
9595
assert df2.shape == (1, 3)
9696
assert df2.c0.iloc[0] == 0
97+
98+
99+
def test_to_parquet_storage_location(glue_database, glue_table, path):
100+
df1 = pd.DataFrame({"c0": [0], "c1": ["foo"], "c2": ["0"]})
101+
df2 = pd.DataFrame({"c0": [1], "c1": ["foo"], "c2": ["1"]})
102+
df3 = pd.DataFrame({"c0": [2], "c1": ["boo"], "c2": ["2"]})
103+
df4 = pd.DataFrame({"c0": [3], "c1": ["boo"], "c2": ["3"]})
104+
105+
wr.s3.to_parquet(df=df1, path=f"{path}foo/0/file0.parquet")
106+
wr.s3.to_parquet(df=df2, path=f"{path}foo/1/file1.parquet")
107+
wr.s3.to_parquet(df=df3, path=f"{path}boo/2/file2.parquet")
108+
wr.s3.to_parquet(df=df4, path=f"{path}boo/3/file3.parquet")
109+
column_types, partitions_types = wr.catalog.extract_athena_types(df1)
110+
wr.catalog.create_parquet_table(
111+
database=glue_database,
112+
table=glue_table,
113+
path=path,
114+
columns_types=column_types,
115+
projection_enabled=True,
116+
projection_types={"c1": "injected", "c2": "injected"},
117+
projection_storage_location_template=f"{path}${{c1}}/${{c2}}",
118+
)
119+
120+
df5 = wr.athena.read_sql_query(f"SELECT * FROM {glue_table} WHERE c1='foo' AND c2='0'", glue_database)
121+
pd.testing.assert_frame_equal(df1, df5, check_dtype=False)

tests/test_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_basics(path, glue_database, glue_table, workgroup0, workgroup1):
126126

127127

128128
def test_athena_cache_configuration():
129+
wr.config.max_remote_cache_entries = 50
129130
wr.config.max_local_cache_entries = 20
130131
assert wr.config.max_remote_cache_entries == 20
131132

0 commit comments

Comments
 (0)