Commit 873c570

Author: Jesse
SQLAlchemy 2: Finish implementing all of ComponentReflectionTest (#251)
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent: 2d4e39b

5 files changed: 262 additions, 118 deletions


src/databricks/sqlalchemy/_ddl.py

Lines changed: 2 additions & 2 deletions
@@ -19,11 +19,11 @@ def post_create_table(self, table):
         return " USING DELTA"
 
     def visit_unique_constraint(self, constraint, **kw):
-        logger.warn("Databricks does not support unique constraints")
+        logger.warning("Databricks does not support unique constraints")
         pass
 
     def visit_check_constraint(self, constraint, **kw):
-        logger.warn("Databricks does not support check constraints")
+        logger.warning("This dialect does not support check constraints")
        pass
 
     def visit_identity_column(self, identity, **kw):
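Note on the change above: Logger.warn is a deprecated alias of Logger.warning in the standard library logging module, so the rename silences a DeprecationWarning without changing what gets logged. A minimal sketch:

    import logging

    logger = logging.getLogger(__name__)

    # Deprecated spelling; emits a DeprecationWarning when warnings are enabled:
    # logger.warn("Databricks does not support unique constraints")

    # Preferred spelling, identical output:
    logger.warning("Databricks does not support unique constraints")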

src/databricks/sqlalchemy/_parse.py

Lines changed: 106 additions & 15 deletions
@@ -1,14 +1,17 @@
 from typing import List, Optional, Dict
 import re
 
+import sqlalchemy
 from sqlalchemy.engine import CursorResult
+from sqlalchemy.engine.interfaces import ReflectedColumn
 
 """
 This module contains helper functions that can parse the contents
 of metadata and exceptions received from DBR. These are mostly just
 wrappers around regexes.
 """
 
+
 def _match_table_not_found_string(message: str) -> bool:
     """Return True if the message contains a substring indicating that a table was not found"""
 
@@ -22,9 +25,10 @@ def _match_table_not_found_string(message: str) -> bool:
     )
 
 
-def _describe_table_extended_result_to_dict_list(result: CursorResult) -> List[Dict[str, str]]:
-    """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries
-    """
+def _describe_table_extended_result_to_dict_list(
+    result: CursorResult,
+) -> List[Dict[str, str]]:
+    """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries"""
 
     rows_to_return = []
     for row in result:
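For orientation, a hedged sketch of the transformation this helper performs. The three-column row shape (col_name, data_type, comment) is an assumption based on typical DESCRIBE TABLE EXTENDED output, and the loop body is elided from this hunk:

    # Hypothetical rows, shaped like DESCRIBE TABLE EXTENDED output
    # (col_name, data_type, comment):
    rows = [("id", "int", None), ("name", "string", None)]

    # One dictionary per row, keyed by the result-set column names:
    dict_list = [
        {"col_name": c, "data_type": t, "comment": m} for (c, t, m) in rows
    ]
    assert dict_list[0] == {"col_name": "id", "data_type": "int", "comment": None}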
@@ -68,22 +72,23 @@ def extract_three_level_identifier_from_constraint_string(input_str: str) -> dict:
     """
     pat = re.compile(r"REFERENCES\s+(.*?)\s*\(")
     matches = pat.findall(input_str)
-    
+
     if not matches:
         return None
-    
+
     first_match = matches[0]
     parts = first_match.split(".")
 
-    def strip_backticks(input:str):
+    def strip_backticks(input: str):
         return input.replace("`", "")
-    
+
     return {
-        "catalog": strip_backticks(parts[0]), 
+        "catalog": strip_backticks(parts[0]),
         "schema": strip_backticks(parts[1]),
-        "table": strip_backticks(parts[2])
+        "table": strip_backticks(parts[2]),
     }
 
+
 def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
     """Build a dictionary of foreign key constraint information from a constraint string.
 
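A hedged usage sketch of extract_three_level_identifier_from_constraint_string, based on the regex and backtick-stripping shown above; the exact constraint string DESCRIBE TABLE EXTENDED emits is an assumption here:

    # Hypothetical constraint string:
    constraint = (
        "FOREIGN KEY (`user_id`) REFERENCES `main`.`pytest_sqlalchemy`.`users` (`id`)"
    )

    identifier = extract_three_level_identifier_from_constraint_string(constraint)
    assert identifier == {
        "catalog": "main",
        "schema": "pytest_sqlalchemy",
        "table": "users",
    }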
@@ -133,6 +138,7 @@ def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
         "referred_schema": referred_schema,
     }
 
+
 def build_fk_dict(
     fk_name: str, fk_constraint_string: str, schema_name: Optional[str]
 ) -> dict:
@@ -172,6 +178,7 @@ def build_fk_dict(
 
     return complete_foreign_key_dict
 
+
 def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]:
     """Build a list of constrained columns from a constraint string returned by DESCRIBE TABLE EXTENDED
 
@@ -188,21 +195,23 @@ def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]:
 
     return _extracted
 
+
 def build_pk_dict(pk_name: str, pk_constraint_string: str) -> dict:
     """Given a primary key name and a primary key constraint string, return a dictionary
     with the following keys:
-    
+
     constrained_columns
       A list of string column names that make up the primary key
 
     name
       The name of the primary key constraint
     """
-    
+
     constrained_columns = _parse_pk_columns_from_constraint_string(pk_constraint_string)
 
     return {"constrained_columns": constrained_columns, "name": pk_name}
-    
+
+
 def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
     """Return a list of dictionaries containing only the col_name:data_type pairs where the `data_type`
     value contains the match argument.
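A hedged sketch of build_pk_dict's output shape, per the return statement above; the PRIMARY KEY rendering in the constraint string is an assumption, and _parse_pk_columns_from_constraint_string does the actual column extraction:

    # Hypothetical inputs:
    pk_name = "pk_users"
    pk_constraint_string = "PRIMARY KEY (`id`, `region`)"

    pk_dict = build_pk_dict(pk_name, pk_constraint_string)
    # Expected shape:
    # {"constrained_columns": ["id", "region"], "name": "pk_users"}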
@@ -221,9 +230,10 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
     for row_dict in dte_output:
         if match in row_dict["data_type"]:
             output_rows.append(row_dict)
-    
+
     return output_rows
 
+
 def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]:
     """If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
     one dictionary per defined constraint
@@ -233,8 +243,10 @@ def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]:
 
     return output
 
-
-def get_pk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[List[dict]]:
+
+def get_pk_strings_from_dte_output(
+    dte_output: List[Dict[str, str]]
+) -> Optional[List[dict]]:
     """If the DESCRIBE TABLE EXTENDED output contains primary key constraints, return a list of dictionaries,
     one dictionary per defined constraint.
@@ -244,3 +256,82 @@ def get_pk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[List[dict]]:
     output = match_dte_rows_by_value(dte_output, "PRIMARY KEY")
 
     return output
+
+
+# The keys of this dictionary are the values we expect to see in a
+# TGetColumnsRequest's .TYPE_NAME attribute.
+# These are enumerated in ttypes.py as class TTypeId.
+# TODO: confirm that all types in TTypeId are included here.
+GET_COLUMNS_TYPE_MAP = {
+    "boolean": sqlalchemy.types.Boolean,
+    "smallint": sqlalchemy.types.SmallInteger,
+    "int": sqlalchemy.types.Integer,
+    "bigint": sqlalchemy.types.BigInteger,
+    "float": sqlalchemy.types.Float,
+    "double": sqlalchemy.types.Float,
+    "string": sqlalchemy.types.String,
+    "varchar": sqlalchemy.types.String,
+    "char": sqlalchemy.types.String,
+    "binary": sqlalchemy.types.String,
+    "array": sqlalchemy.types.String,
+    "map": sqlalchemy.types.String,
+    "struct": sqlalchemy.types.String,
+    "uniontype": sqlalchemy.types.String,
+    "decimal": sqlalchemy.types.Numeric,
+    "timestamp": sqlalchemy.types.DateTime,
+    "date": sqlalchemy.types.Date,
+}
+
+
+def parse_numeric_type_precision_and_scale(type_name_str):
+    """Return an instantiated sqlalchemy Numeric() type that preserves the precision and scale indicated
+    in the output from TGetColumnsRequest.
+
+    type_name_str
+      The value of TGetColumnsReq.TYPE_NAME.
+
+    If type_name_str is "DECIMAL(18,5)" returns sqlalchemy.types.Numeric(18,5)
+    """
+
+    pattern = re.compile(r"DECIMAL\((\d+,\d+)\)")
+    match = re.search(pattern, type_name_str)
+    precision_and_scale = match.group(1)
+    precision, scale = tuple(precision_and_scale.split(","))
+
+    return sqlalchemy.types.Numeric(int(precision), int(scale))
+
+
+def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColumn:
+    """Returns a dictionary of the ReflectedColumn schema parsed from
+    a single row of the result of a TGetColumnsRequest thrift RPC
+    """
+
+    pat = re.compile(r"^\w+")
+    _raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower()
+    _col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type]
+
+    if _raw_col_type == "decimal":
+        final_col_type = parse_numeric_type_precision_and_scale(
+            thrift_resp_row.TYPE_NAME
+        )
+    else:
+        final_col_type = _col_type
+
+    # See comments about autoincrement in test_suite.py
+    # Since Databricks SQL doesn't currently support inline AUTOINCREMENT declarations,
+    # the autoincrement must be manually declared with an Identity() construct in SQLAlchemy.
+    # Other dialects can perform this extra Identity() step automatically. But that is not
+    # implemented in the Databricks dialect right now. So autoincrement is currently always False.
+    # It's not clear what IS_AUTO_INCREMENT in the thrift response actually reflects or whether
+    # it ever returns a `YES`.
+
+    # Per the guidance in SQLAlchemy's docstrings, we prefer to not even include an autoincrement
+    # key in this dictionary.
+    this_column = {
+        "name": thrift_resp_row.COLUMN_NAME,
+        "type": final_col_type,
+        "nullable": bool(thrift_resp_row.NULLABLE),
+        "default": thrift_resp_row.COLUMN_DEF,
+    }
+
+    return this_column
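To see how the two new helpers compose, here is a minimal sketch; FakeThriftRow is a hypothetical stand-in that models only the TGetColumnsResponse attributes the parser reads:

    from collections import namedtuple

    # Hypothetical stand-in for one row of a TGetColumnsRequest response:
    FakeThriftRow = namedtuple(
        "FakeThriftRow", ["COLUMN_NAME", "TYPE_NAME", "NULLABLE", "COLUMN_DEF"]
    )

    row = FakeThriftRow(
        COLUMN_NAME="amount", TYPE_NAME="DECIMAL(18,5)", NULLABLE=1, COLUMN_DEF=None
    )

    column = parse_column_info_from_tgetcolumnsresponse(row)
    # The "decimal" branch routes through parse_numeric_type_precision_and_scale,
    # so column["type"] is sqlalchemy.types.Numeric(18, 5); column["nullable"] is
    # True and column["default"] is None. No autoincrement key is included.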
