|
2 | 2 | import datetime |
3 | 3 | import decimal |
4 | 4 | import re |
5 | | -from enum import Enum |
6 | | -from typing import Any, Dict, Generic, Optional, Sequence, Type, TypeVar |
| 5 | +from abc import ABC, abstractmethod |
| 6 | +from typing import Any, Callable, Dict, Optional, Sequence, Type |
7 | 7 |
|
| 8 | +from typing_extensions import Literal |
8 | 9 |
|
9 | | -class _EngineType(Enum): |
10 | | - PRESTO = "presto" |
11 | | - HIVE = "hive" |
12 | | - PARTIQL = "partiql" |
| 10 | +from awswrangler import exceptions |
13 | 11 |
|
14 | | - def __str__(self) -> str: |
15 | | - return self.value |
| 12 | +_EngineTypeLiteral = Literal["presto", "hive", "partiql"] |
16 | 13 |
|
17 | 14 |
|
18 | | -_NoneType = type(None) |
19 | | -_PythonType = TypeVar("_PythonType") |
20 | | -_PythonTypeMapValue = TypeVar("_PythonTypeMapValue") |
class _Engine(ABC):
    """Base class that renders Python values as SQL literal text for one query engine.

    Subclasses must implement :meth:`format_string`; every other ``format_*`` hook has
    a default here and may be overridden. :meth:`format` picks the hook that matches
    the runtime type of a value.
    """

    def __init__(self, engine_name: _EngineTypeLiteral) -> None:
        """Remember the engine name so subclasses can mention it in error messages."""
        self.engine_name = engine_name

    def format_null(self, value: None = None) -> str:
        """Return the literal for ``None``; ``value`` is unused."""
        return "NULL"

    @abstractmethod
    def format_string(self, value: str) -> str:
        """Return ``value`` as a quoted and escaped string literal."""
        pass

    def format_bool(self, value: bool) -> str:
        """Return ``TRUE`` or ``FALSE``."""
        return str(value).upper()

    def format_integer(self, value: int) -> str:
        """Return the decimal text of ``value``."""
        return str(value)

    def format_float(self, value: float) -> str:
        """Return ``value`` in fixed-point notation (six digits after the point)."""
        return f"{value:f}"

    def format_decimal(self, value: decimal.Decimal) -> str:
        """Return a ``DECIMAL '...'`` literal written without an exponent."""
        return f"DECIMAL '{value:f}'"

    def format_timestamp(self, value: datetime.datetime) -> str:
        """Return a ``TIMESTAMP '...'`` literal with millisecond precision.

        Raises:
            TypeError: if ``value`` carries timezone information.
        """
        if value.tzinfo is not None:
            # The guard rejects aware values, so the message must name the kind that
            # IS supported (naive); the previous wording said the opposite.
            raise TypeError(f"Supports only timezone naive datatype, got {value}.")

        return f"TIMESTAMP '{value.isoformat(sep=' ', timespec='milliseconds')}'"

    def format_date(self, value: datetime.date) -> str:
        """Return a ``DATE 'YYYY-MM-DD'`` literal."""
        return f"DATE '{value.isoformat()}'"

    def format_array(self, value: Sequence[Any]) -> str:
        """Return an ``ARRAY [...]`` literal, formatting each element via :meth:`format`.

        Elements are emitted in iteration order, which is arbitrary for a ``set``.
        """
        return f"ARRAY [{', '.join(map(self.format, value))}]"

    def format_dict(self, value: Dict[Any, Any]) -> str:
        """Return a ``MAP(ARRAY [keys], ARRAY [values])`` literal (``MAP()`` when empty).

        Raises:
            TypeError: if a key is ``None`` or is not an instance of the first key's type.
        """
        if not value:
            return "MAP()"

        map_keys = list(value.keys())
        key_type = type(map_keys[0])
        for key in map_keys:
            if key is None:
                raise TypeError("Map key cannot be null.")
            if not isinstance(key, key_type):
                raise TypeError("All Map key elements must be the same type.")

        map_values = list(value.values())
        return (
            f"MAP(ARRAY [{', '.join(map(self.format, map_keys))}], ARRAY [{', '.join(map(self.format, map_values))}])"
        )

    def format(self, data: Any) -> str:
        """Render ``data`` with the first ``format_*`` hook whose type it is an instance of.

        Raises:
            TypeError: if no hook is registered for the type of ``data``.
        """
        # Insertion order is significant because matching uses isinstance():
        # bool must precede int (bool subclasses int) and datetime.datetime must
        # precede datetime.date (datetime subclasses date).
        formats_dict: Dict[Type[Any], Callable[[Any], str]] = {
            bool: self.format_bool,
            str: self.format_string,
            int: self.format_integer,
            datetime.datetime: self.format_timestamp,
            datetime.date: self.format_date,
            decimal.Decimal: self.format_decimal,
            float: self.format_float,
            list: self.format_array,
            tuple: self.format_array,
            set: self.format_array,
            dict: self.format_dict,
        }

        if data is None:
            return self.format_null()

        for python_type, format_func in formats_dict.items():
            if isinstance(data, python_type):
                return format_func(data)

        raise TypeError(f"Unsupported type {type(data)} in parameter.")
71 | 90 |
|
72 | | -class _FloatType(_AbstractType[float]): |
73 | | - def __str__(self) -> str: |
74 | | - return f"{self.data:f}" |
75 | 91 |
|
class _PrestoEngine(_Engine):
    """Formatter registered under the ``presto`` engine name.

    Only string quoting is defined here; every other literal comes from ``_Engine``.
    """

    def __init__(self) -> None:
        super().__init__(engine_name="presto")

    def format_string(self, value: str) -> str:
        """Single-quote ``value``, doubling each embedded single quote."""
        escaped = value.replace("'", "''")
        return f"'{escaped}'"
81 | 98 |
|
82 | | - return f"DECIMAL '{self.data:f}'" |
83 | 99 |
|
class _HiveEngine(_Engine):
    """Formatter registered under the ``hive`` engine name.

    Only string quoting is defined here; every other literal comes from ``_Engine``.
    """

    # Per-character escape table. One translate() pass yields the same text as
    # replacing backslashes first and the remaining characters afterwards, because
    # the output of a substitution is never scanned again.
    _ESCAPES = str.maketrans(
        {
            "\\": "\\\\",
            "'": "\\'",
            "\r": "\\r",
            "\n": "\\n",
            "\t": "\\t",
        }
    )

    def __init__(self) -> None:
        super().__init__(engine_name="hive")

    def format_string(self, value: str) -> str:
        """Single-quote ``value`` with backslash escapes for ``\\``, ``'``, CR, LF and TAB."""
        escaped = value.translate(self._ESCAPES)
        return f"'{escaped}'"
89 | 112 |
|
90 | | - if self.engine == _EngineType.PARTIQL: |
91 | | - return f"'{self.data.isoformat()}'" |
92 | 113 |
|
93 | | - return f"TIMESTAMP '{self.data.isoformat(sep=' ', timespec='milliseconds')}'" |
class _PartiQLEngine(_Engine):
    """Formatter registered under the ``partiql`` engine name.

    Scalars are rendered as bare or single-quoted values (no ``DECIMAL``/``TIMESTAMP``/
    ``DATE`` keyword prefixes); arrays and maps are not supported.
    """

    def __init__(self) -> None:
        super().__init__("partiql")

    def format_null(self, value: None = None) -> str:
        """Return the lowercase ``null`` literal; ``value`` is unused."""
        return "null"

    def format_string(self, value: str) -> str:
        """Single-quote ``value``, doubling each embedded single quote."""
        return f"""'{value.replace("'", "''")}'"""

    def format_bool(self, value: bool) -> str:
        """Return ``1`` for true and ``0`` for false."""
        return "1" if value else "0"

    def format_decimal(self, value: decimal.Decimal) -> str:
        """Return ``str(value)`` wrapped in single quotes."""
        return f"'{value}'"

    def format_timestamp(self, value: datetime.datetime) -> str:
        """Return the ISO-8601 text of ``value`` wrapped in single quotes.

        Raises:
            TypeError: if ``value`` carries timezone information.
        """
        if value.tzinfo is not None:
            # The guard rejects aware values, so the message must name the kind that
            # IS supported (naive); the previous wording said the opposite.
            raise TypeError(f"Supports only timezone naive datatype, got {value}.")

        return f"'{value.isoformat()}'"

    def format_date(self, value: datetime.date) -> str:
        """Return the ISO-8601 date wrapped in single quotes."""
        return f"'{value.isoformat()}'"

    def format_array(self, value: Sequence[Any]) -> str:
        """Always raise: array literals are not implemented for this engine."""
        raise NotImplementedError(f"format_array not implemented for engine={self.engine_name}.")

    def format_dict(self, value: Dict[Any, Any]) -> str:
        """Always raise: map literals are not implemented for this engine."""
        raise NotImplementedError(f"format_dict not implemented for engine={self.engine_name}.")
119 | 144 |
|
120 | | - map_keys = list(self.data.keys()) |
121 | | - key_type = type(map_keys[0]) |
122 | | - for key in map_keys: |
123 | | - if isinstance(key, _NullType): |
124 | | - raise TypeError("Map key cannot be null.") |
125 | | - if not isinstance(key, key_type): |
126 | | - raise TypeError("All Map key elements must be the same type.") |
127 | 145 |
|
128 | | - map_values = list(self.data.values()) |
129 | | - return f"MAP(ARRAY [{', '.join(map(str, map_keys))}], ARRAY [{', '.join(map(str, map_values))}])" |
130 | | - |
131 | | - |
132 | | -_FORMATS: Dict[Type[Any], Type[_AbstractType[_PythonType]]] = { # type: ignore[valid-type] |
133 | | - bool: _BooleanType, |
134 | | - str: _StringType, |
135 | | - int: _IntegerType, |
136 | | - datetime.datetime: _TimestampType, |
137 | | - datetime.date: _DateType, |
138 | | - decimal.Decimal: _DecimalType, |
139 | | - float: _FloatType, |
140 | | -} |
141 | | - |
142 | | -_ARRAY_FORMATS: Dict[Type[Any], Type[_AbstractType[_PythonType]]] = { # type: ignore[valid-type] |
143 | | - list: _ArrayType, |
144 | | - tuple: _ArrayType, |
145 | | - set: _ArrayType, |
146 | | -} |
147 | | - |
148 | | -_MAP_FORMATS: Dict[Type[Any], Type[_AbstractType[_PythonType]]] = { # type: ignore[valid-type] |
149 | | - dict: _MapType, |
150 | | -} |
151 | | - |
152 | | - |
153 | | -def _create_abstract_type( |
154 | | - data: _PythonType, |
155 | | - engine: _EngineType, |
156 | | -) -> _AbstractType[_PythonType]: |
157 | | - if data is None: |
158 | | - return _NullType(data=data, engine=engine) |
159 | | - |
160 | | - for python_type, format_type in _FORMATS.items(): |
161 | | - if isinstance(data, python_type): |
162 | | - return format_type(data=data, engine=engine) |
163 | | - |
164 | | - for python_type, format_type in _ARRAY_FORMATS.items(): |
165 | | - if isinstance(data, python_type): |
166 | | - return format_type( |
167 | | - [_create_abstract_type(item, engine=engine) for item in data], |
168 | | - engine=engine, |
169 | | - ) |
170 | | - |
171 | | - for python_type, format_type in _MAP_FORMATS.items(): |
172 | | - if isinstance(data, python_type): |
173 | | - return format_type( |
174 | | - data={ |
175 | | - _create_abstract_type(mk, engine=engine): _create_abstract_type(mv, engine=engine) |
176 | | - for mk, mv in data.items() |
177 | | - }, |
178 | | - engine=engine, |
179 | | - ) |
180 | | - |
181 | | - raise TypeError(f"Unsupported type {type(data)} in parameter.") |
182 | | - |
183 | | - |
184 | | -def _format_parameters(params: Dict[str, Any], engine: _EngineType) -> Dict[str, Any]: |
def _format_parameters(params: Dict[str, Any], engine: _Engine) -> Dict[str, Any]:
    """Render every parameter value as literal text using ``engine``.

    Args:
        params: Mapping of parameter name to raw Python value.
        engine: Formatter whose ``format`` method turns one value into its literal text.

    Returns:
        A new dict with the same keys, each mapped to ``engine.format(data=value)``.
        Any exception raised by ``engine.format`` propagates unchanged.
    """
    return {name: engine.format(data=value) for name, value in params.items()}
192 | 153 |
|
193 | 154 |
|
# Matches a ":name" placeholder: a colon followed by one or more ASCII letters, digits or
# underscores (captured as group 1), not immediately followed by another such character.
_PATTERN = re.compile(r":([A-Za-z0-9_]+)(?![A-Za-z0-9_])")
195 | 156 |
|
196 | 157 |
|
197 | | -def _process_sql_params(sql: str, params: Optional[Dict[str, Any]], engine: _EngineType = _EngineType.PRESTO) -> str: |
def _create_engine(engine_type: _EngineTypeLiteral) -> _Engine:
    """Build the formatter that corresponds to ``engine_type``.

    Raises:
        exceptions.InvalidArgumentValue: if ``engine_type`` matches none of the known names.
    """
    # Checked in order with ==, exactly like the equivalent chain of if-statements.
    known_engines = (
        ("hive", _HiveEngine),
        ("presto", _PrestoEngine),
        ("partiql", _PartiQLEngine),
    )
    for name, engine_cls in known_engines:
        if engine_type == name:
            return engine_cls()

    raise exceptions.InvalidArgumentValue(f"Unknown engine type: {engine_type}")
| 169 | + |
| 170 | + |
| 171 | +def _process_sql_params(sql: str, params: Optional[Dict[str, Any]], engine_type: _EngineTypeLiteral = "presto") -> str: |
198 | 172 | if params is None: |
199 | 173 | params = {} |
200 | 174 |
|
| 175 | + engine = _create_engine(engine_type) |
201 | 176 | processed_params = _format_parameters(params, engine=engine) |
202 | 177 |
|
203 | 178 | def replace(match: re.Match) -> str: # type: ignore[type-arg] |
|
0 commit comments