Commit 2d73100

Expand SQL formatter to LakeFormation (#1684)
1 parent 4f4a73c commit 2d73100

File tree

7 files changed: +133, -41 lines changed


awswrangler/athena/_formatter.py renamed to awswrangler/_sql_formatter.py

Lines changed: 50 additions & 3 deletions
```diff
@@ -1,13 +1,15 @@
-"""Formatting logic for Athena parameters."""
+"""Formatting logic for SQL parameters."""
 import datetime
 import decimal
+import re
 from enum import Enum
-from typing import Any, Dict, Generic, Sequence, Type, TypeVar
+from typing import Any, Dict, Generic, Optional, Sequence, Type, TypeVar
 
 
 class _EngineType(Enum):
     PRESTO = "presto"
     HIVE = "hive"
+    PARTIQL = "partiql"
 
     def __str__(self) -> str:
         return self.value
@@ -29,14 +31,17 @@ def __str__(self) -> str:
 
 class _NullType(_AbstractType[_NoneType]):
     def __str__(self) -> str:
+        if self.engine == _EngineType.PARTIQL:
+            return "null"
+
         return "NULL"
 
 
 class _StringType(_AbstractType[str]):
     supported_formats = {"s", "i"}
 
     def __str__(self) -> str:
-        if self.engine == _EngineType.PRESTO:
+        if self.engine in [_EngineType.PRESTO, _EngineType.PARTIQL]:
             return f"""'{self.data.replace("'", "''")}'"""
 
         if self.engine == _EngineType.HIVE:
@@ -53,6 +58,9 @@ def __str__(self) -> str:
 
 class _BooleanType(_AbstractType[bool]):
     def __str__(self) -> str:
+        if self.engine == _EngineType.PARTIQL:
+            return "1" if self.data else "0"
+
         return str(self.data).upper()
 
 
@@ -68,28 +76,44 @@ def __str__(self) -> str:
 
 class _DecimalType(_AbstractType[decimal.Decimal]):
     def __str__(self) -> str:
+        if self.engine == _EngineType.PARTIQL:
+            return f"'{self.data}'"
+
         return f"DECIMAL '{self.data:f}'"
 
 
 class _TimestampType(_AbstractType[datetime.datetime]):
     def __str__(self) -> str:
         if self.data.tzinfo is not None:
             raise TypeError(f"Supports only timezone aware datatype, got {self.data}.")
+
+        if self.engine == _EngineType.PARTIQL:
+            return f"'{self.data.isoformat()}'"
+
         return f"TIMESTAMP '{self.data.isoformat(sep=' ', timespec='milliseconds')}'"
 
 
 class _DateType(_AbstractType[datetime.date]):
     def __str__(self) -> str:
+        if self.engine == _EngineType.PARTIQL:
+            return f"'{self.data.isoformat()}'"
+
         return f"DATE '{self.data.isoformat()}'"
 
 
 class _ArrayType(_AbstractType[Sequence[_PythonType]]):
     def __str__(self) -> str:
+        if self.engine == _EngineType.PARTIQL:
+            super().__str__()
+
         return f"ARRAY [{', '.join(map(str, self.data))}]"
 
 
 class _MapType(_AbstractType[Dict[_PythonType, _PythonTypeMapValue]]):
     def __str__(self) -> str:
+        if self.engine == _EngineType.PARTIQL:
+            super().__str__()
+
         if not self.data:
             return "MAP()"
 
@@ -165,3 +189,26 @@ def _format_parameters(params: Dict[str, Any], engine: _EngineType) -> Dict[str,
         processed_params[k] = str(abs_type)
 
     return processed_params
+
+
+_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")
+
+
+def _process_sql_params(sql: str, params: Optional[Dict[str, Any]], engine: _EngineType = _EngineType.PRESTO) -> str:
+    if params is None:
+        params = {}
+
+    processed_params = _format_parameters(params, engine=engine)
+
+    def replace(match: re.Match) -> str:  # type: ignore
+        key = match.group(1)
+
+        if key not in processed_params:
+            # do not replace anything if the parameter is not provided
+            return str(match.group(0))
+
+        return str(processed_params[key])
+
+    sql = _PATTERN.sub(replace, sql)
+
+    return sql
```
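
The new `_process_sql_params` helper is what both engines now share: it scans the SQL for `:name` placeholders and swaps in values rendered by the per-type formatters above, leaving unknown placeholders untouched. Below is a minimal standalone sketch of that substitution approach for the PartiQL case; `format_value_partiql` and `process_sql_params` are simplified illustrations written for this note, not the library's private `_format_parameters` / `_process_sql_params`.

```python
import datetime
import re
from typing import Any, Dict, Optional

# Same placeholder pattern as in the diff: ":name" not followed by another word character.
PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")


def format_value_partiql(value: Any) -> str:
    """Simplified stand-in for the PartiQL branches of the formatter (illustrative only)."""
    if value is None:
        return "null"
    if isinstance(value, bool):
        return "1" if value else "0"
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"  # escape single quotes by doubling them
    if isinstance(value, (datetime.datetime, datetime.date)):
        return f"'{value.isoformat()}'"
    return str(value)


def process_sql_params(sql: str, params: Optional[Dict[str, Any]]) -> str:
    """Replace :name placeholders; names without a provided value are left untouched."""
    formatted = {k: format_value_partiql(v) for k, v in (params or {}).items()}

    def replace(match: "re.Match[str]") -> str:
        key = match.group(1)
        return formatted.get(key, match.group(0))

    return PATTERN.sub(replace, sql)


print(process_sql_params(
    'SELECT * FROM my_table WHERE "string" = :name AND "bool" = :flag',
    {"name": "O'Hara", "flag": True},
))
# SELECT * FROM my_table WHERE "string" = 'O''Hara' AND "bool" = 1
```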

awswrangler/athena/_read.py

Lines changed: 3 additions & 27 deletions
```diff
@@ -2,7 +2,6 @@
 
 import csv
 import logging
-import re
 import sys
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -14,7 +13,7 @@
 from awswrangler import _utils, catalog, exceptions, s3
 from awswrangler._config import apply_configs
 from awswrangler._data_types import cast_pandas_with_athena_types
-from awswrangler.athena._formatter import _EngineType, _format_parameters
+from awswrangler._sql_formatter import _process_sql_params
 from awswrangler.athena._utils import (
     _apply_query_metadata,
     _empty_dataframe_response,
@@ -568,29 +567,6 @@ def _unload(
     return query_metadata
 
 
-_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")
-
-
-def _process_sql_params(sql: str, params: Optional[Dict[str, Any]]) -> str:
-    if params is None:
-        params = {}
-
-    processed_params = _format_parameters(params, engine=_EngineType.PRESTO)
-
-    def replace(match: re.Match) -> str:  # type: ignore
-        key = match.group(1)
-
-        if key not in processed_params:
-            # do not replace anything if the parameter is not provided
-            return str(match.group(0))
-
-        return str(processed_params[key])
-
-    sql = _PATTERN.sub(replace, sql)
-
-    return sql
-
-
 @apply_configs
 def get_query_results(
     query_execution_id: str,
@@ -922,7 +898,7 @@ def read_sql_query(
     >>> import awswrangler as wr
     >>> df = wr.athena.read_sql_query(
     ...     sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
-    ...     params={"name": "'filtered_name'", "city": "'filtered_city'"}
+    ...     params={"name": "filtered_name", "city": "filtered_city"}
     ... )
 
     """
@@ -1303,7 +1279,7 @@ def unload(
     >>> import awswrangler as wr
     >>> res = wr.athena.unload(
     ...     sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
-    ...     params={"name": "'filtered_name'", "city": "'filtered_city'"}
+    ...     params={"name": "filtered_name", "city": "filtered_city"}
     ... )
 
     """
```

awswrangler/athena/_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -16,7 +16,7 @@
 
 from awswrangler import _data_types, _utils, catalog, exceptions, s3, sts
 from awswrangler._config import apply_configs
-from awswrangler.athena._formatter import _EngineType, _format_parameters
+from awswrangler._sql_formatter import _EngineType, _format_parameters
 from awswrangler.catalog._utils import _catalog_id, _transaction_id
 
 from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results, _LocalMetadataCacheManager
```

awswrangler/lakeformation/_read.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -10,6 +10,7 @@
 from awswrangler import _data_types, _utils, catalog
 from awswrangler._config import apply_configs
 from awswrangler._distributed import engine
+from awswrangler._sql_formatter import _EngineType, _process_sql_params
 from awswrangler._threading import _get_executor
 from awswrangler.catalog._utils import _catalog_id, _transaction_id
 from awswrangler.lakeformation._utils import commit_transaction, start_transaction, wait_query
@@ -157,17 +158,15 @@ def read_sql_query(
     ...     sql="SELECT * FROM my_table WHERE name=:name AND city=:city",
     ...     database="my_db",
     ...     query_as_of_time="1611142914",
-    ...     params={"name": "'filtered_name'", "city": "'filtered_city'"}
+    ...     params={"name": "filtered_name", "city": "filtered_city"}
     ... )
 
     """
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
     client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session)
     commit_trans: bool = False
-    if params is None:
-        params = {}
-    for key, value in params.items():
-        sql = sql.replace(f":{key};", str(value))
+
+    sql = _process_sql_params(sql, params, engine=_EngineType.PARTIQL)
 
     if not any([transaction_id, query_as_of_time]):
         _logger.debug("Neither `transaction_id` nor `query_as_of_time` were specified, starting transaction")
```
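
Before this change the Lake Formation reader substituted parameters with a bare `sql.replace(f":{key};", str(value))`, which required a trailing semicolon after each placeholder and pre-quoted values; it now delegates to the shared formatter with the PartiQL engine. A hedged usage sketch mirroring the updated docstring (table, database, and the `query_as_of_time` value are placeholders):

```python
import awswrangler as wr

# The PartiQL formatter quotes strings, renders dates/timestamps as ISO strings,
# and renders booleans as 1/0, so raw Python values can be passed directly.
df = wr.lakeformation.read_sql_query(
    sql='SELECT * FROM my_table WHERE "name" = :name AND "city" = :city',
    database="my_db",
    query_as_of_time="1611142914",  # example snapshot time taken from the docstring
    params={"name": "filtered_name", "city": "filtered_city"},
)
```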

tests/_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -105,7 +105,7 @@ def get_df_list(governed=False):
     df["category"] = df["category"].astype("category")
 
     if governed:
-        df = (df.drop(["iint8", "binary"], axis=1),)  # tinyint & binary currently not supported
+        df = df.drop(["iint8", "binary"], axis=1)  # tinyint & binary currently not supported
     return df
 
 
```
tests/unit/test_lakeformation.py

Lines changed: 73 additions & 3 deletions
@@ -1,13 +1,15 @@
11
import calendar
2+
import datetime as dt
23
import logging
34
import time
5+
from decimal import Decimal
46

57
import pytest
68

79
import awswrangler as wr
810
from awswrangler._distributed import EngineEnum, MemoryFormatEnum
911

10-
from .._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv
12+
from .._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv, get_df_list
1113

1214
if wr.engine.get() == EngineEnum.RAY and wr.memory_format.get() == MemoryFormatEnum.MODIN:
1315
import modin.pandas as pd
@@ -50,9 +52,9 @@ def test_lakeformation(path, path2, glue_database, glue_table, glue_table2, use_
5052

5153
# Filter query
5254
df2 = wr.lakeformation.read_sql_query(
53-
sql=f"SELECT * FROM {glue_table} WHERE iint16 = :iint16;",
55+
sql=f'SELECT * FROM {glue_table} WHERE "string" = :city_name',
5456
database=glue_database,
55-
params={"iint16": 1},
57+
params={"city_name": "Washington"},
5658
)
5759
assert len(df2.index) == 1
5860

@@ -145,3 +147,71 @@ def test_lakeformation_multi_transaction(path, path2, glue_database, glue_table,
145147

146148
assert df2.shape == df4.shape
147149
assert df2.c1.sum() == df4.c1.sum()
150+
151+
152+
@pytest.mark.parametrize(
153+
"col_name,col_value",
154+
[
155+
("date", dt.date(2020, 1, 1)),
156+
("timestamp", dt.datetime(2020, 1, 1)),
157+
("bool", True),
158+
("decimal", Decimal(("1.99"))),
159+
("float", 0.0),
160+
("iint16", 1),
161+
],
162+
)
163+
def test_lakeformation_partiql_formatting(path, path2, glue_database, glue_table, glue_table2, col_name, col_value):
164+
wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
165+
wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table2)
166+
167+
wr.s3.to_parquet(
168+
df=get_df_list(governed=True),
169+
path=path,
170+
index=False,
171+
boto3_session=None,
172+
s3_additional_kwargs=None,
173+
dataset=True,
174+
partition_cols=["par0", "par1"],
175+
mode="overwrite",
176+
table=glue_table,
177+
table_type="GOVERNED",
178+
database=glue_database,
179+
)
180+
181+
# Filter query
182+
df = wr.lakeformation.read_sql_query(
183+
sql=f'SELECT * FROM {glue_table} WHERE "{col_name}" = :col_value',
184+
database=glue_database,
185+
params={"col_value": col_value},
186+
)
187+
assert len(df) == 1
188+
189+
190+
def test_lakeformation_partiql_formatting_escape_string(path, path2, glue_database, glue_table, glue_table2):
191+
df = pd.DataFrame(
192+
{
193+
"id": [1, 2, 3],
194+
"string": ["normal string", "'weird' string", "another normal string"],
195+
}
196+
)
197+
198+
wr.s3.to_parquet(
199+
df=df,
200+
path=path,
201+
index=False,
202+
boto3_session=None,
203+
s3_additional_kwargs=None,
204+
dataset=True,
205+
mode="overwrite",
206+
table=glue_table,
207+
table_type="GOVERNED",
208+
database=glue_database,
209+
)
210+
211+
# Filter query
212+
df = wr.lakeformation.read_sql_query(
213+
sql=f'SELECT * FROM {glue_table} WHERE "string" = :col_value',
214+
database=glue_database,
215+
params={"col_value": "'weird' string"},
216+
)
217+
assert len(df) == 1
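
For reference, the parametrized cases above should be rendered by the PartiQL engine roughly as the literals below; this is a hedged summary derived from the `_sql_formatter.py` changes earlier in this commit, not captured test output.

```python
import datetime as dt
from decimal import Decimal

# (Python value, expected PartiQL literal) pairs, per the formatter classes above.
expected_partiql_literals = [
    (dt.date(2020, 1, 1), "'2020-01-01'"),                # _DateType: quoted ISO date
    (dt.datetime(2020, 1, 1), "'2020-01-01T00:00:00'"),   # _TimestampType: quoted ISO timestamp
    (True, "1"),                                          # _BooleanType: 1 / 0
    (Decimal("1.99"), "'1.99'"),                          # _DecimalType: quoted decimal string
    ("'weird' string", "'''weird'' string'"),             # _StringType: single quotes doubled
]
```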

tests/unit/test_athena_params_formatter.py renamed to tests/unit/test_sql_params_formatter.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 
 import pytest
 
-from awswrangler.athena._formatter import _EngineType, _format_parameters
+from awswrangler._sql_formatter import _EngineType, _format_parameters
 
 
 @pytest.mark.parametrize("engine", [_EngineType.HIVE, _EngineType.PRESTO])
```
