55from datetime import date , datetime , timezone
66from decimal import Decimal
77from enum import Enum
8- from typing import List , Optional , Sequence , Union
8+ from io import StringIO
9+ from typing import Any , Dict , List , Optional , Sequence , Union
910
1011from sqlparse import parse as parse_sql # type: ignore
1112from sqlparse .sql import ( # type: ignore
@@ -62,8 +63,6 @@ def parse_datetime(datetime_string: str) -> datetime:
6263# These definitions are required by PEP-249
6364Date = date
6465
65- _AccountInfo = namedtuple ("_AccountInfo" , ["id" , "version" ])
66-
6766
6867def DateFromTicks (t : int ) -> date : # NOSONAR
6968 """Convert `ticks` to `date` for Firebolt DB."""
@@ -109,16 +108,28 @@ def Binary(value: str) -> bytes: # NOSONAR
109108)
110109
111110
112- class ARRAY :
class ExtendedType:
    """Base type for all extended types in Firebolt (array, decimal, struct, etc.).

    Instances hash by their string representation, so a subclass only needs
    to provide a `__str__` that is identical for instances it considers equal.
    """

    __name__ = "ExtendedType"

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Keep subclasses hashable when they define `__eq__`.

        Python implicitly sets `__hash__ = None` on any class that defines
        `__eq__` without also defining `__hash__`, which would silently
        discard the string-based hash inherited from this base class.
        Restore it for such subclasses. A subclass that provides its own
        `__hash__` is left untouched.

        Note: an explicit `__hash__ = None` next to `__eq__` cannot be told
        apart from the implicit one and is overridden as well.
        """
        super().__init_subclass__(**kwargs)
        namespace = cls.__dict__
        if "__eq__" in namespace and namespace.get("__hash__") is None:
            # setattr instead of plain assignment keeps type checkers from
            # complaining about assigning to a method.
            setattr(cls, "__hash__", ExtendedType.__hash__)

    @staticmethod
    def is_valid_type(type_: Any) -> bool:
        """Tell whether `type_` can be used as a column (sub)type.

        Args:
            type_: Candidate type to check.

        Returns:
            True if `type_` is a member of the module-level `_col_types`
            collection or is an instance of `ExtendedType`.
        """
        return type_ in _col_types or isinstance(type_, ExtendedType)

    def __hash__(self) -> int:
        return hash(str(self))
122+
123+
124+ class ARRAY (ExtendedType ):
113125 """Class for holding `array` column type information in Firebolt DB."""
114126
115127 __name__ = "Array"
116128 _prefix = "array("
117129
118- def __init__ (self , subtype : Union [type , ARRAY , DECIMAL ]):
119- assert (subtype in _col_types and subtype is not list ) or isinstance (
120- subtype , (ARRAY , DECIMAL )
121- ), f"Invalid array subtype: { str (subtype )} "
130+ def __init__ (self , subtype : Union [type , ExtendedType ]):
131+ if not self .is_valid_type (subtype ):
132+ raise ValueError (f"Invalid array subtype: { str (subtype )} " )
122133 self .subtype = subtype
123134
124135 def __str__ (self ) -> str :
@@ -130,7 +141,7 @@ def __eq__(self, other: object) -> bool:
130141 return other .subtype == self .subtype
131142
132143
133- class DECIMAL :
144+ class DECIMAL ( ExtendedType ) :
134145 """Class for holding `decimal` value information in Firebolt DB."""
135146
136147 __name__ = "Decimal"
@@ -143,15 +154,29 @@ def __init__(self, precision: int, scale: int):
143154 def __str__ (self ) -> str :
144155 return f"Decimal({ self .precision } , { self .scale } )"
145156
146- def __hash__ (self ) -> int :
147- return hash (str (self ))
148-
149157 def __eq__ (self , other : object ) -> bool :
150158 if not isinstance (other , DECIMAL ):
151159 return NotImplemented
152160 return other .precision == self .precision and other .scale == self .scale
153161
154162
class STRUCT(ExtendedType):
    """Class for holding `struct` column type information in Firebolt DB.

    Args:
        fields: Mapping of field name to field type. Every type must pass
            `is_valid_type`.

    Raises:
        ValueError: If any field type is not a valid column type.
    """

    __name__ = "Struct"
    _prefix = "struct("

    def __init__(self, fields: Dict[str, Union[type, ExtendedType]]):
        # Only the types need validating; field names are not inspected.
        for type_ in fields.values():
            if not self.is_valid_type(type_):
                raise ValueError(f"Invalid struct field type: {str(type_)}")
        self.fields = fields

    def __str__(self) -> str:
        return f"Struct({', '.join(f'{k}: {v}' for k, v in self.fields.items())})"

    def __hash__(self) -> int:
        # Defining __eq__ below would otherwise make Python set __hash__ to
        # None for this class. Equality compares the `fields` dicts, which
        # ignores insertion order, so the hash must be order-independent too:
        # hash the set of field names. Equal structs always share it.
        return hash(frozenset(self.fields))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, STRUCT):
            # Let Python try the reflected comparison, like DECIMAL does.
            return NotImplemented
        return other.fields == self.fields
178+
179+
155180NULLABLE_SUFFIX = "null"
156181
157182
@@ -206,7 +231,31 @@ def python_type(self) -> type:
206231 return types [self ]
207232
208233
209- def parse_type (raw_type : str ) -> Union [type , ARRAY , DECIMAL ]: # noqa: C901
def split_struct_fields(raw_struct: str) -> List[str]:
    """Split raw struct inner fields string into a list of field definitions.

    Only commas outside of any parentheses act as separators, so nested
    type definitions stay intact. The pieces are returned exactly as they
    appear in the input (surrounding whitespace is NOT stripped), and an
    empty input yields a single empty string.

    >>> split_struct_fields("field1 int, field2 struct(field1 int, field2 text)")
    ['field1 int', ' field2 struct(field1 int, field2 text)']

    Args:
        raw_struct: The text between the outer parentheses of a struct type.

    Returns:
        The top-level field definitions, in input order.
    """
    separator = ","
    balance = 0  # keep track of the level of nesting, and only split on level 0
    res: List[str] = []
    start = 0  # index where the current field definition begins
    for i, ch in enumerate(raw_struct):
        if ch == "(":
            balance += 1
        elif ch == ")":
            balance -= 1
        elif ch == separator and balance == 0:
            res.append(raw_struct[start:i])
            start = i + 1

    # Whatever follows the last top-level separator is the final field.
    res.append(raw_struct[start:])
    return res
256+
257+
258+ def parse_type (raw_type : str ) -> Union [type , ExtendedType ]: # noqa: C901
210259 """Parse typename provided by query metadata into Python type."""
211260 if not isinstance (raw_type , str ):
212261 raise DataError (f"Invalid typename { str (raw_type )} : str expected" )
@@ -218,10 +267,20 @@ def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901
218267 try :
219268 prec_scale = raw_type [len (DECIMAL ._prefix ) : - 1 ].split ("," )
220269 precision , scale = int (prec_scale [0 ]), int (prec_scale [1 ])
270+ return DECIMAL (precision , scale )
221271 except (ValueError , IndexError ):
222272 pass
223- else :
224- return DECIMAL (precision , scale )
273+ # Handle structs
274+ if raw_type .startswith (STRUCT ._prefix ) and raw_type .endswith (")" ):
275+ try :
276+ fields_raw = split_struct_fields (raw_type [len (STRUCT ._prefix ) : - 1 ])
277+ fields = {}
278+ for f in fields_raw :
279+ name , type_ = f .strip ().split (" " , 1 )
280+ fields [name .strip ()] = parse_type (type_ .strip ())
281+ return STRUCT (fields )
282+ except ValueError :
283+ pass
225284 # Handle nullable
226285 if raw_type .endswith (NULLABLE_SUFFIX ):
227286 return parse_type (raw_type [: - len (NULLABLE_SUFFIX )].strip (" " ))
@@ -247,13 +306,13 @@ def _parse_bytea(str_value: str) -> bytes:
247306
248307def parse_value (
249308 value : RawColType ,
250- ctype : Union [type , ARRAY , DECIMAL ],
309+ ctype : Union [type , ExtendedType ],
251310) -> ColType :
252311 """Provided raw value, and Python type; parses first into Python value."""
253312 if value is None :
254313 return None
255314 if ctype in (int , str , float ):
256- assert isinstance (ctype , type )
315+ assert isinstance (ctype , type ) # assertion for mypy
257316 return ctype (value )
258317 if ctype is date :
259318 if not isinstance (value , str ):
@@ -273,11 +332,20 @@ def parse_value(
273332 raise DataError (f"Invalid bytea value { value } : str expected" )
274333 return _parse_bytea (value )
275334 if isinstance (ctype , DECIMAL ):
276- assert isinstance (value , (str , int ))
335+ if not isinstance (value , (str , int )):
336+ raise DataError (f"Invalid decimal value { value } : str or int expected" )
277337 return Decimal (value )
278338 if isinstance (ctype , ARRAY ):
279- assert isinstance (value , list )
339+ if not isinstance (value , list ):
340+ raise DataError (f"Invalid array value { value } : list expected" )
280341 return [parse_value (it , ctype .subtype ) for it in value ]
342+ if isinstance (ctype , STRUCT ):
343+ if not isinstance (value , dict ):
344+ raise DataError (f"Invalid struct value { value } : dict expected" )
345+ return {
346+ name : parse_value (value .get (name ), type_ )
347+ for name , type_ in ctype .fields .items ()
348+ }
281349 raise DataError (f"Unsupported data type returned: { ctype .__name__ } " )
282350
283351
0 commit comments