Skip to content

Commit c4d9b70

Browse files
refactor: SQL formatter (#2288)
1 parent e20faf7 commit c4d9b70

File tree

3 files changed

+135
-157
lines changed

3 files changed

+135
-157
lines changed

awswrangler/_sql_formatter.py

Lines changed: 117 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -2,202 +2,177 @@
22
import datetime
33
import decimal
44
import re
5-
from enum import Enum
6-
from typing import Any, Dict, Generic, Optional, Sequence, Type, TypeVar
5+
from abc import ABC, abstractmethod
6+
from typing import Any, Callable, Dict, Optional, Sequence, Type
77

8+
from typing_extensions import Literal
89

9-
class _EngineType(Enum):
10-
PRESTO = "presto"
11-
HIVE = "hive"
12-
PARTIQL = "partiql"
10+
from awswrangler import exceptions
1311

14-
def __str__(self) -> str:
15-
return self.value
12+
_EngineTypeLiteral = Literal["presto", "hive", "partiql"]
1613

1714

18-
_NoneType = type(None)
19-
_PythonType = TypeVar("_PythonType")
20-
_PythonTypeMapValue = TypeVar("_PythonTypeMapValue")
15+
class _Engine(ABC):
16+
def __init__(self, engine_name: _EngineTypeLiteral) -> None:
17+
self.engine_name = engine_name
2118

19+
def format_null(self, value: None = None) -> str:
20+
return "NULL"
2221

23-
class _AbstractType(Generic[_PythonType]):
24-
def __init__(self, data: _PythonType, engine: _EngineType):
25-
self.data: _PythonType = data
26-
self.engine: _EngineType = engine
22+
@abstractmethod
23+
def format_string(self, value: str) -> str:
24+
pass
2725

28-
def __str__(self) -> str:
29-
raise NotImplementedError(f"{type(self)} not implemented for engine={self.engine}.")
26+
def format_bool(self, value: bool) -> str:
27+
return str(value).upper()
3028

29+
def format_integer(self, value: int) -> str:
30+
return str(value)
3131

32-
class _NullType(_AbstractType[_NoneType]):
33-
def __str__(self) -> str:
34-
if self.engine == _EngineType.PARTIQL:
35-
return "null"
32+
def format_float(self, value: float) -> str:
33+
return f"{value:f}"
3634

37-
return "NULL"
35+
def format_decimal(self, value: decimal.Decimal) -> str:
36+
return f"DECIMAL '{value:f}'"
3837

38+
def format_timestamp(self, value: datetime.datetime) -> str:
39+
if value.tzinfo is not None:
40+
raise TypeError(f"Supports only timezone aware datatype, got {value}.")
3941

40-
class _StringType(_AbstractType[str]):
41-
supported_formats = {"s", "i"}
42+
return f"TIMESTAMP '{value.isoformat(sep=' ', timespec='milliseconds')}'"
4243

43-
def __str__(self) -> str:
44-
if self.engine in [_EngineType.PRESTO, _EngineType.PARTIQL]:
45-
return f"""'{self.data.replace("'", "''")}'"""
44+
def format_date(self, value: datetime.date) -> str:
45+
return f"DATE '{value.isoformat()}'"
4646

47-
if self.engine == _EngineType.HIVE:
48-
return "'{}'".format(
49-
self.data.replace("\\", "\\\\")
50-
.replace("'", "\\'")
51-
.replace("\r", "\\r")
52-
.replace("\n", "\\n")
53-
.replace("\t", "\\t")
54-
)
47+
def format_array(self, value: Sequence[Any]) -> str:
48+
return f"ARRAY [{', '.join(map(self.format, value))}]"
5549

56-
return super().__str__()
50+
def format_dict(self, value: Dict[Any, Any]) -> str:
51+
if not value:
52+
return "MAP()"
5753

54+
map_keys = list(value.keys())
55+
key_type = type(map_keys[0])
56+
for key in map_keys:
57+
if key is None:
58+
raise TypeError("Map key cannot be null.")
59+
if not isinstance(key, key_type):
60+
raise TypeError("All Map key elements must be the same type.")
5861

59-
class _BooleanType(_AbstractType[bool]):
60-
def __str__(self) -> str:
61-
if self.engine == _EngineType.PARTIQL:
62-
return "1" if self.data else "0"
62+
map_values = list(value.values())
63+
return (
64+
f"MAP(ARRAY [{', '.join(map(self.format, map_keys))}], ARRAY [{', '.join(map(self.format, map_values))}])"
65+
)
6366

64-
return str(self.data).upper()
67+
def format(self, data: Any) -> str:
68+
formats_dict: Dict[Type[Any], Callable[[Any], str]] = {
69+
bool: self.format_bool,
70+
str: self.format_string,
71+
int: self.format_integer,
72+
datetime.datetime: self.format_timestamp,
73+
datetime.date: self.format_date,
74+
decimal.Decimal: self.format_decimal,
75+
float: self.format_float,
76+
list: self.format_array,
77+
tuple: self.format_array,
78+
set: self.format_array,
79+
dict: self.format_dict,
80+
}
6581

82+
if data is None:
83+
return self.format_null()
6684

67-
class _IntegerType(_AbstractType[int]):
68-
def __str__(self) -> str:
69-
return str(self.data)
85+
for python_type, format_func in formats_dict.items():
86+
if isinstance(data, python_type):
87+
return format_func(data)
7088

89+
raise TypeError(f"Unsupported type {type(data)} in parameter.")
7190

72-
class _FloatType(_AbstractType[float]):
73-
def __str__(self) -> str:
74-
return f"{self.data:f}"
7591

92+
class _PrestoEngine(_Engine):
93+
def __init__(self) -> None:
94+
super().__init__("presto")
7695

77-
class _DecimalType(_AbstractType[decimal.Decimal]):
78-
def __str__(self) -> str:
79-
if self.engine == _EngineType.PARTIQL:
80-
return f"'{self.data}'"
96+
def format_string(self, value: str) -> str:
97+
return f"""'{value.replace("'", "''")}'"""
8198

82-
return f"DECIMAL '{self.data:f}'"
8399

100+
class _HiveEngine(_Engine):
101+
def __init__(self) -> None:
102+
super().__init__("hive")
84103

85-
class _TimestampType(_AbstractType[datetime.datetime]):
86-
def __str__(self) -> str:
87-
if self.data.tzinfo is not None:
88-
raise TypeError(f"Supports only timezone aware datatype, got {self.data}.")
104+
def format_string(self, value: str) -> str:
105+
return "'{}'".format(
106+
value.replace("\\", "\\\\")
107+
.replace("'", "\\'")
108+
.replace("\r", "\\r")
109+
.replace("\n", "\\n")
110+
.replace("\t", "\\t")
111+
)
89112

90-
if self.engine == _EngineType.PARTIQL:
91-
return f"'{self.data.isoformat()}'"
92113

93-
return f"TIMESTAMP '{self.data.isoformat(sep=' ', timespec='milliseconds')}'"
114+
class _PartiQLEngine(_Engine):
115+
def __init__(self) -> None:
116+
super().__init__("partiql")
94117

118+
def format_null(self, value: None = None) -> str:
119+
return "null"
95120

96-
class _DateType(_AbstractType[datetime.date]):
97-
def __str__(self) -> str:
98-
if self.engine == _EngineType.PARTIQL:
99-
return f"'{self.data.isoformat()}'"
121+
def format_string(self, value: str) -> str:
122+
return f"""'{value.replace("'", "''")}'"""
100123

101-
return f"DATE '{self.data.isoformat()}'"
124+
def format_bool(self, value: bool) -> str:
125+
return "1" if value else "0"
102126

127+
def format_decimal(self, value: decimal.Decimal) -> str:
128+
return f"'{value}'"
103129

104-
class _ArrayType(_AbstractType[Sequence[_PythonType]]):
105-
def __str__(self) -> str:
106-
if self.engine == _EngineType.PARTIQL:
107-
super().__str__()
130+
def format_timestamp(self, value: datetime.datetime) -> str:
131+
if value.tzinfo is not None:
132+
raise TypeError(f"Supports only timezone aware datatype, got {value}.")
108133

109-
return f"ARRAY [{', '.join(map(str, self.data))}]"
134+
return f"'{value.isoformat()}'"
110135

136+
def format_date(self, value: datetime.date) -> str:
137+
return f"'{value.isoformat()}'"
111138

112-
class _MapType(_AbstractType[Dict[_PythonType, _PythonTypeMapValue]]):
113-
def __str__(self) -> str:
114-
if self.engine == _EngineType.PARTIQL:
115-
super().__str__()
139+
def format_array(self, value: Sequence[Any]) -> str:
140+
raise NotImplementedError(f"format_array not implemented for engine={self.engine_name}.")
116141

117-
if not self.data:
118-
return "MAP()"
142+
def format_dict(self, value: Dict[Any, Any]) -> str:
143+
raise NotImplementedError(f"format_dict not implemented for engine={self.engine_name}.")
119144

120-
map_keys = list(self.data.keys())
121-
key_type = type(map_keys[0])
122-
for key in map_keys:
123-
if isinstance(key, _NullType):
124-
raise TypeError("Map key cannot be null.")
125-
if not isinstance(key, key_type):
126-
raise TypeError("All Map key elements must be the same type.")
127145

128-
map_values = list(self.data.values())
129-
return f"MAP(ARRAY [{', '.join(map(str, map_keys))}], ARRAY [{', '.join(map(str, map_values))}])"
130-
131-
132-
_FORMATS: Dict[Type[Any], Type[_AbstractType[_PythonType]]] = { # type: ignore[valid-type]
133-
bool: _BooleanType,
134-
str: _StringType,
135-
int: _IntegerType,
136-
datetime.datetime: _TimestampType,
137-
datetime.date: _DateType,
138-
decimal.Decimal: _DecimalType,
139-
float: _FloatType,
140-
}
141-
142-
_ARRAY_FORMATS: Dict[Type[Any], Type[_AbstractType[_PythonType]]] = { # type: ignore[valid-type]
143-
list: _ArrayType,
144-
tuple: _ArrayType,
145-
set: _ArrayType,
146-
}
147-
148-
_MAP_FORMATS: Dict[Type[Any], Type[_AbstractType[_PythonType]]] = { # type: ignore[valid-type]
149-
dict: _MapType,
150-
}
151-
152-
153-
def _create_abstract_type(
154-
data: _PythonType,
155-
engine: _EngineType,
156-
) -> _AbstractType[_PythonType]:
157-
if data is None:
158-
return _NullType(data=data, engine=engine)
159-
160-
for python_type, format_type in _FORMATS.items():
161-
if isinstance(data, python_type):
162-
return format_type(data=data, engine=engine)
163-
164-
for python_type, format_type in _ARRAY_FORMATS.items():
165-
if isinstance(data, python_type):
166-
return format_type(
167-
[_create_abstract_type(item, engine=engine) for item in data],
168-
engine=engine,
169-
)
170-
171-
for python_type, format_type in _MAP_FORMATS.items():
172-
if isinstance(data, python_type):
173-
return format_type(
174-
data={
175-
_create_abstract_type(mk, engine=engine): _create_abstract_type(mv, engine=engine)
176-
for mk, mv in data.items()
177-
},
178-
engine=engine,
179-
)
180-
181-
raise TypeError(f"Unsupported type {type(data)} in parameter.")
182-
183-
184-
def _format_parameters(params: Dict[str, Any], engine: _EngineType) -> Dict[str, Any]:
146+
def _format_parameters(params: Dict[str, Any], engine: _Engine) -> Dict[str, Any]:
185147
processed_params = {}
186148

187149
for k, v in params.items():
188-
abs_type = _create_abstract_type(data=v, engine=engine)
189-
processed_params[k] = str(abs_type)
150+
processed_params[k] = engine.format(data=v)
190151

191152
return processed_params
192153

193154

194155
_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")
195156

196157

197-
def _process_sql_params(sql: str, params: Optional[Dict[str, Any]], engine: _EngineType = _EngineType.PRESTO) -> str:
158+
def _create_engine(engine_type: _EngineTypeLiteral) -> _Engine:
159+
if engine_type == "hive":
160+
return _HiveEngine()
161+
162+
if engine_type == "presto":
163+
return _PrestoEngine()
164+
165+
if engine_type == "partiql":
166+
return _PartiQLEngine()
167+
168+
raise exceptions.InvalidArgumentValue(f"Unknown engine type: {engine_type}")
169+
170+
171+
def _process_sql_params(sql: str, params: Optional[Dict[str, Any]], engine_type: _EngineTypeLiteral = "presto") -> str:
198172
if params is None:
199173
params = {}
200174

175+
engine = _create_engine(engine_type)
201176
processed_params = _format_parameters(params, engine=engine)
202177

203178
def replace(match: re.Match) -> str: # type: ignore[type-arg]

awswrangler/lakeformation/_read.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from awswrangler._config import apply_configs
1313
from awswrangler._distributed import engine
1414
from awswrangler._executor import _BaseExecutor, _get_executor
15-
from awswrangler._sql_formatter import _EngineType, _process_sql_params
15+
from awswrangler._sql_formatter import _process_sql_params
1616
from awswrangler.catalog._utils import _catalog_id, _transaction_id
1717
from awswrangler.lakeformation._utils import commit_transaction, start_transaction, wait_query
1818

@@ -177,7 +177,7 @@ def read_sql_query(
177177
client_lakeformation = _utils.client(service_name="lakeformation", session=boto3_session)
178178
commit_trans: bool = False
179179

180-
sql = _process_sql_params(sql, params, engine=_EngineType.PARTIQL)
180+
sql = _process_sql_params(sql, params, engine_type="partiql")
181181

182182
if not any([transaction_id, query_as_of_time]):
183183
_logger.debug("Neither `transaction_id` nor `query_as_of_time` were specified, starting transaction")

0 commit comments

Comments
 (0)