11"""Internal (private) Data Types Module."""
22
33import logging
4+ import re
45from decimal import Decimal
5- from typing import Dict , List , Optional , Tuple
6+ from typing import Any , Dict , List , Match , Optional , Sequence , Tuple
67
78import pandas as pd # type: ignore
89import pyarrow as pa # type: ignore
@@ -139,8 +140,10 @@ def pyarrow2athena(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branc
        return f"decimal({dtype.precision},{dtype.scale})"
    if pa.types.is_list(dtype):
        return f"array<{pyarrow2athena(dtype=dtype.value_type)}>"
-    if pa.types.is_struct(dtype):  # pragma: no cover
-        return f"struct<{', '.join([f'{f.name}: {pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+    if pa.types.is_struct(dtype):
+        return f"struct<{', '.join([f'{f.name}:{pyarrow2athena(dtype=f.type)}' for f in dtype])}>"
+    if pa.types.is_map(dtype):  # pragma: no cover
+        return f"map<{pyarrow2athena(dtype=dtype.key_type)},{pyarrow2athena(dtype=dtype.item_type)}>"
    if dtype == pa.null():
        raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
    raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
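For reference, a minimal sketch of what the struct branch and the new map branch render (assuming the module is importable as `awswrangler._data_types`; the scalar names such as `bigint`, `int`, and `string` come from earlier branches of the same function, which are not shown in this hunk):

```python
import pyarrow as pa

from awswrangler._data_types import pyarrow2athena  # assumed import path

# Nested PyArrow types are rendered recursively into Athena DDL type strings.
print(pyarrow2athena(dtype=pa.struct([("a", pa.int64()), ("b", pa.string())])))
# struct<a:bigint, b:string>
print(pyarrow2athena(dtype=pa.map_(pa.string(), pa.int32())))
# map<string,int>
```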
@@ -167,7 +170,7 @@ def pyarrow2pandas_extension(  # pylint: disable=too-many-branches,too-many-retu

def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
    dtype: pa.DataType, db_type: str
-) -> VisitableType:
+) -> Optional[VisitableType]:
    """Pyarrow to SQLAlchemy data types conversion."""
    if pa.types.is_int8(dtype):
        return sqlalchemy.types.SmallInteger
@@ -207,14 +210,14 @@ def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-sta
        return sqlalchemy.types.Date
    if pa.types.is_binary(dtype):
        if db_type == "redshift":
-            raise exceptions.UnsupportedType(f"Binary columns are not supported for Redshift.")  # pragma: no cover
+            raise exceptions.UnsupportedType("Binary columns are not supported for Redshift.")  # pragma: no cover
        return sqlalchemy.types.Binary
    if pa.types.is_decimal(dtype):
        return sqlalchemy.types.Numeric(precision=dtype.precision, scale=dtype.scale)
    if pa.types.is_dictionary(dtype):
        return pyarrow2sqlalchemy(dtype=dtype.value_type, db_type=db_type)
    if dtype == pa.null():  # pragma: no cover
-        raise exceptions.UndetectedType("We can not infer the data type from an entire null object column")
+        return None
    raise exceptions.UnsupportedType(f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover

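A quick sketch of the behavioral change above: a fully null column now yields `None` instead of raising `UndetectedType`, so callers can decide whether to drop or override it. The `"mysql"` value here is illustrative; any `db_type` other than `"redshift"` behaves the same for this branch.

```python
import pyarrow as pa

from awswrangler._data_types import pyarrow2sqlalchemy  # assumed import path

# Before this change an all-null column raised UndetectedType;
# now the caller receives None and handles it explicitly.
assert pyarrow2sqlalchemy(dtype=pa.null(), db_type="mysql") is None
```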
@@ -243,12 +246,23 @@ def pyarrow_types_from_pandas(
        else:
            cols.append(name)

-    # Filling cols_dtypes and indexes
+    # Filling cols_dtypes
+    for col in cols:
+        _logger.debug("Inferring PyArrow type from column: %s", col)
+        try:
+            schema: pa.Schema = pa.Schema.from_pandas(df=df[[col]], preserve_index=False)
+        except pa.ArrowInvalid as ex:  # pragma: no cover
+            cols_dtypes[col] = process_not_inferred_dtype(ex)
+        else:
+            cols_dtypes[col] = schema.field(col).type
+
+    # Filling indexes
    indexes: List[str] = []
-    for field in pa.Schema.from_pandas(df=df[cols], preserve_index=index):
-        name = str(field.name)
-        cols_dtypes[name] = field.type
-        if (name not in df.columns) and (index is True):
+    if index is True:
+        for field in pa.Schema.from_pandas(df=df[[]], preserve_index=True):
+            name = str(field.name)
+            _logger.debug("Inferring PyArrow type from index: %s", name)
+            cols_dtypes[name] = field.type
            indexes.append(name)

    # Merging Index
@@ -257,10 +271,43 @@
    # Filling schema
    columns_types: Dict[str, pa.DataType]
    columns_types = {n: cols_dtypes[n] for n in sorted_cols}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
    return columns_types

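The per-column loop introduced above means one problematic column no longer aborts inference for the whole DataFrame. A hedged sketch of the effect (the exact inferred types and the error path depend on the Pandas and PyArrow versions in use):

```python
import uuid

import pandas as pd

from awswrangler._data_types import pyarrow_types_from_pandas  # assumed import path

# Only the "id" column trips PyArrow's inference; it is recovered as a
# string type via process_not_inferred_dtype instead of failing everything.
df = pd.DataFrame({"id": [uuid.uuid4(), uuid.uuid4()], "value": [1, 2]})
print(pyarrow_types_from_pandas(df=df, index=False))
# e.g. {'id': DataType(string), 'value': DataType(int64)}
```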
+def process_not_inferred_dtype(ex: pa.ArrowInvalid) -> pa.DataType:
+    """Infer data type from PyArrow inference exception."""
+    ex_str = str(ex)
+    _logger.debug("PyArrow was not able to infer data type:\n%s", ex_str)
+    match: Optional[Match] = re.search(
+        pattern="Could not convert (.*) with type (.*): did not recognize "
+        "Python value type when inferring an Arrow data type",
+        string=ex_str,
+    )
+    if match is None:
+        raise ex  # pragma: no cover
+    groups: Optional[Sequence[str]] = match.groups()
+    if groups is None:
+        raise ex  # pragma: no cover
+    if len(groups) != 2:
+        raise ex  # pragma: no cover
+    _logger.debug("groups: %s", groups)
+    type_str: str = groups[1]
+    if type_str == "UUID":
+        return pa.string()
+    raise ex  # pragma: no cover
+
+
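A sketch of the exception message this regex is written against (wording taken from the PyArrow releases this change targets; other versions may phrase it differently, in which case the original exception is re-raised):

```python
import uuid

import pyarrow as pa

from awswrangler._data_types import process_not_inferred_dtype  # assumed import path

try:
    # Raises ArrowInvalid: "Could not convert UUID('...') with type UUID:
    # did not recognize Python value type when inferring an Arrow data type"
    pa.array([uuid.uuid4()])
except pa.ArrowInvalid as ex:
    assert process_not_inferred_dtype(ex) == pa.string()
```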
+def process_not_inferred_array(ex: pa.ArrowInvalid, values: Any) -> pa.Array:
+    """Infer `pyarrow.array` from PyArrow inference exception."""
+    dtype = process_not_inferred_dtype(ex=ex)
+    if dtype == pa.string():
+        array: pa.Array = pa.array(obj=[str(x) for x in values], type=dtype, safe=True)
+    else:
+        raise ex  # pragma: no cover
+    return array
+
+
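And a companion usage sketch: retry an array build by stringifying the values PyArrow rejected (again assuming the `awswrangler._data_types` import path):

```python
import uuid

import pyarrow as pa

from awswrangler._data_types import process_not_inferred_array  # assumed import path

values = [uuid.uuid4(), uuid.uuid4()]
try:
    array = pa.array(obj=values)
except pa.ArrowInvalid as ex:
    array = process_not_inferred_array(ex, values=values)  # falls back to string
print(array.type)  # string
```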
def athena_types_from_pandas(
    df: pd.DataFrame, index: bool, dtype: Optional[Dict[str, str]] = None, index_left: bool = False
) -> Dict[str, str]:
@@ -275,7 +322,7 @@ def athena_types_from_pandas(
            athena_columns_types[k] = casts[k]
        else:
            athena_columns_types[k] = pyarrow2athena(dtype=v)
-    _logger.debug(f"athena_columns_types: {athena_columns_types}")
+    _logger.debug("athena_columns_types: %s", athena_columns_types)
    return athena_columns_types

@@ -315,7 +362,7 @@ def pyarrow_schema_from_pandas(
        if (k in df.columns) and (k not in ignore):
            columns_types[k] = athena2pyarrow(v)
    columns_types = {k: v for k, v in columns_types.items() if v is not None}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
    return pa.schema(fields=columns_types)

@@ -324,11 +371,11 @@ def athena_types_from_pyarrow_schema(
) -> Tuple[Dict[str, str], Optional[Dict[str, str]]]:
    """Extract the related Athena data types from any PyArrow Schema considering possible partitions."""
    columns_types: Dict[str, str] = {str(f.name): pyarrow2athena(dtype=f.type) for f in schema}
-    _logger.debug(f"columns_types: {columns_types}")
+    _logger.debug("columns_types: %s", columns_types)
    partitions_types: Optional[Dict[str, str]] = None
    if partitions is not None:
        partitions_types = {p.name: pyarrow2athena(p.dictionary.type) for p in partitions}
-    _logger.debug(f"partitions_types: {partitions_types}")
+    _logger.debug("partitions_types: %s", partitions_types)
    return columns_types, partitions_types

@@ -372,7 +419,7 @@ def sqlalchemy_types_from_pandas(
    df: pd.DataFrame, db_type: str, dtype: Optional[Dict[str, VisitableType]] = None
) -> Dict[str, VisitableType]:
    """Extract the related SQLAlchemy data types from any Pandas DataFrame."""
-    casts: Dict[str, VisitableType] = dtype if dtype else {}
+    casts: Dict[str, VisitableType] = dtype if dtype is not None else {}
    pa_columns_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(
        df=df, index=False, ignore_cols=list(casts.keys())
    )
@@ -382,5 +429,5 @@
            sqlalchemy_columns_types[k] = casts[k]
        else:
            sqlalchemy_columns_types[k] = pyarrow2sqlalchemy(dtype=v, db_type=db_type)
-    _logger.debug(f"sqlalchemy_columns_types: {sqlalchemy_columns_types}")
+    _logger.debug("sqlalchemy_columns_types: %s", sqlalchemy_columns_types)
    return sqlalchemy_columns_types
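To close, a hedged usage sketch of the extractor above (the `"mysql"` db_type and the `Text` override are illustrative): any column named in `dtype` is passed through as-is via `ignore_cols`, while the rest go through PyArrow inference and `pyarrow2sqlalchemy`.

```python
import pandas as pd
import sqlalchemy  # type: ignore

from awswrangler._data_types import sqlalchemy_types_from_pandas  # assumed import path

df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
types = sqlalchemy_types_from_pandas(
    df=df,
    db_type="mysql",  # illustrative db_type value
    dtype={"b": sqlalchemy.types.Text},  # explicit cast skips inference for "b"
)
print(types)  # e.g. {'a': <inferred integer type>, 'b': Text}
```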