11from typing import List , Optional , Dict
22import re
33
4+ import sqlalchemy
45from sqlalchemy .engine import CursorResult
6+ from sqlalchemy .engine .interfaces import ReflectedColumn
57
68"""
79This module contains helper functions that can parse the contents
810of metadata and exceptions received from DBR. These are mostly just
911wrappers around regexes.
1012"""
1113
14+
1215def _match_table_not_found_string (message : str ) -> bool :
1316 """Return True if the message contains a substring indicating that a table was not found"""
1417
@@ -22,9 +25,10 @@ def _match_table_not_found_string(message: str) -> bool:
2225 )
2326
2427
25- def _describe_table_extended_result_to_dict_list (result : CursorResult ) -> List [Dict [str , str ]]:
26- """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries
27- """
28+ def _describe_table_extended_result_to_dict_list (
29+ result : CursorResult ,
30+ ) -> List [Dict [str , str ]]:
31+ """Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries"""
2832
2933 rows_to_return = []
3034 for row in result :
@@ -68,22 +72,23 @@ def extract_three_level_identifier_from_constraint_string(input_str: str) -> dic
6872 """
6973 pat = re .compile (r"REFERENCES\s+(.*?)\s*\(" )
7074 matches = pat .findall (input_str )
71-
75+
7276 if not matches :
7377 return None
74-
78+
7579 first_match = matches [0 ]
7680 parts = first_match .split ("." )
7781
78- def strip_backticks (input :str ):
82+ def strip_backticks (input : str ):
7983 return input .replace ("`" , "" )
80-
84+
8185 return {
82- "catalog" : strip_backticks (parts [0 ]),
86+ "catalog" : strip_backticks (parts [0 ]),
8387 "schema" : strip_backticks (parts [1 ]),
84- "table" : strip_backticks (parts [2 ])
88+ "table" : strip_backticks (parts [2 ]),
8589 }
8690
91+
8792def _parse_fk_from_constraint_string (constraint_str : str ) -> dict :
8893 """Build a dictionary of foreign key constraint information from a constraint string.
8994
@@ -133,6 +138,7 @@ def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
133138 "referred_schema" : referred_schema ,
134139 }
135140
141+
136142def build_fk_dict (
137143 fk_name : str , fk_constraint_string : str , schema_name : Optional [str ]
138144) -> dict :
@@ -172,6 +178,7 @@ def build_fk_dict(
172178
173179 return complete_foreign_key_dict
174180
181+
175182def _parse_pk_columns_from_constraint_string (constraint_str : str ) -> List [str ]:
176183 """Build a list of constrained columns from a constraint string returned by DESCRIBE TABLE EXTENDED
177184
@@ -188,21 +195,23 @@ def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]:
188195
189196 return _extracted
190197
198+
def build_pk_dict(pk_name: str, pk_constraint_string: str) -> dict:
    """Build a SQLAlchemy-style primary key dictionary from a constraint name and string.

    Returns a dictionary with the following keys:

    constrained_columns
        A list of string column names that make up the primary key

    name
        The name of the primary key constraint
    """

    column_names = _parse_pk_columns_from_constraint_string(pk_constraint_string)

    return {"name": pk_name, "constrained_columns": column_names}
205-
213+
214+
206215def match_dte_rows_by_value (dte_output : List [Dict [str , str ]], match : str ) -> List [dict ]:
207216 """Return a list of dictionaries containing only the col_name:data_type pairs where the `data_type`
208217 value contains the match argument.
@@ -221,9 +230,10 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis
221230 for row_dict in dte_output :
222231 if match in row_dict ["data_type" ]:
223232 output_rows .append (row_dict )
224-
233+
225234 return output_rows
226235
236+
227237def get_fk_strings_from_dte_output (dte_output : List [List ]) -> List [dict ]:
228238 """If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
229239 one dictionary per defined constraint
@@ -233,8 +243,10 @@ def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]:
233243
234244 return output
235245
236-
237- def get_pk_strings_from_dte_output (dte_output : List [Dict [str , str ]]) -> Optional [List [dict ]]:
246+
def get_pk_strings_from_dte_output(
    dte_output: List[Dict[str, str]]
) -> Optional[List[dict]]:
    """Return the primary key constraint rows found in DESCRIBE TABLE EXTENDED output.

    If the output defines primary key constraints, the returned list contains one
    dictionary per defined constraint; an empty list means none were defined.
    """

    return match_dte_rows_by_value(dte_output, "PRIMARY KEY")
259+
260+
# The keys of this dictionary are the values we expect to see in a
# TGetColumnsRequest's .TYPE_NAME attribute.
# These are enumerated in ttypes.py as class TTypeId.
# TODO: confirm that all types in TTypeId are included here.
GET_COLUMNS_TYPE_MAP = {
    "boolean": sqlalchemy.types.Boolean,
    "smallint": sqlalchemy.types.SmallInteger,
    "int": sqlalchemy.types.Integer,
    "bigint": sqlalchemy.types.BigInteger,
    "float": sqlalchemy.types.Float,
    # NOTE(review): "double" also maps to Float here — presumably because this
    # SQLAlchemy version predates a dedicated Double type; confirm this is intended.
    "double": sqlalchemy.types.Float,
    "string": sqlalchemy.types.String,
    "varchar": sqlalchemy.types.String,
    "char": sqlalchemy.types.String,
    "binary": sqlalchemy.types.String,
    # Complex/nested types are surfaced as plain strings rather than
    # structured SQLAlchemy types.
    "array": sqlalchemy.types.String,
    "map": sqlalchemy.types.String,
    "struct": sqlalchemy.types.String,
    "uniontype": sqlalchemy.types.String,
    # "decimal" gets precision/scale re-attached by
    # parse_numeric_type_precision_and_scale() at reflection time.
    "decimal": sqlalchemy.types.Numeric,
    "timestamp": sqlalchemy.types.DateTime,
    "date": sqlalchemy.types.Date,
}
284+
285+
def parse_numeric_type_precision_and_scale(type_name_str):
    """Return an instantiated sqlalchemy Numeric() type that preserves the precision
    and scale indicated in the output from TGetColumnsRequest.

    type_name_str
        The value of TGetColumnsReq.TYPE_NAME.

    If type_name_str is "DECIMAL(18,5)" returns sqlalchemy.types.Numeric(18,5)

    Raises AttributeError if type_name_str does not contain a
    DECIMAL(<precision>,<scale>) substring, mirroring the original behavior
    for malformed input.
    """

    # Capture precision and scale as separate groups so no manual string
    # splitting is needed; \s* tolerates an optional space after the comma.
    pattern = re.compile(r"DECIMAL\((\d+),\s*(\d+)\)")
    match = pattern.search(type_name_str)
    precision, scale = match.group(1), match.group(2)

    return sqlalchemy.types.Numeric(int(precision), int(scale))
302+
303+
def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColumn:
    """Returns a dictionary of the ReflectedColumn schema parsed from
    a single of the result of a TGetColumnsRequest thrift RPC
    """

    # The leading word of TYPE_NAME (e.g. "DECIMAL" in "DECIMAL(18,5)")
    # selects the SQLAlchemy type class.
    leading_word = re.compile(r"^\w+")
    base_type_name = re.search(leading_word, thrift_resp_row.TYPE_NAME).group(0).lower()

    if base_type_name == "decimal":
        # Decimals carry precision/scale that must be preserved on the type object.
        final_col_type = parse_numeric_type_precision_and_scale(thrift_resp_row.TYPE_NAME)
    else:
        final_col_type = GET_COLUMNS_TYPE_MAP[base_type_name]

    # See comments about autoincrement in test_suite.py
    # Databricks SQL doesn't currently support inline AUTOINCREMENT declarations,
    # so autoincrement must be manually declared with an Identity() construct in
    # SQLAlchemy. Other dialects can perform this extra Identity() step
    # automatically, but that is not implemented in the Databricks dialect right
    # now, so autoincrement is currently always False. It's not clear what
    # IS_AUTO_INCREMENT in the thrift response actually reflects or whether it
    # ever returns a `YES`.

    # Per the guidance in SQLAlchemy's docstrings, we prefer to not even include
    # an autoincrement key in this dictionary.
    return {
        "name": thrift_resp_row.COLUMN_NAME,
        "type": final_col_type,
        "nullable": bool(thrift_resp_row.NULLABLE),
        "default": thrift_resp_row.COLUMN_DEF,
    }
0 commit comments