11"""Internal (private) Data Types Module."""
22
33import logging
4+ import re
45from decimal import Decimal
5- from typing import Dict , List , Optional , Tuple
6+ from typing import Any , Dict , List , Match , Optional , Sequence , Tuple
67
78import pandas as pd # type: ignore
89import pyarrow as pa # type: ignore
@@ -139,8 +140,10 @@ def pyarrow2athena(dtype: pa.DataType) -> str: # pylint: disable=too-many-branc
139140 return f"decimal({ dtype .precision } ,{ dtype .scale } )"
140141 if pa .types .is_list (dtype ):
141142 return f"array<{ pyarrow2athena (dtype = dtype .value_type )} >"
142- if pa .types .is_struct (dtype ): # pragma: no cover
143- return f"struct<{ ', ' .join ([f'{ f .name } : { pyarrow2athena (dtype = f .type )} ' for f in dtype ])} >"
143+ if pa .types .is_struct (dtype ):
144+ return f"struct<{ ', ' .join ([f'{ f .name } :{ pyarrow2athena (dtype = f .type )} ' for f in dtype ])} >"
145+ if pa .types .is_map (dtype ): # pragma: no cover
146+ return f"map<{ pyarrow2athena (dtype = dtype .key_type )} ,{ pyarrow2athena (dtype = dtype .item_type )} >"
144147 if dtype == pa .null ():
145148 raise exceptions .UndetectedType ("We can not infer the data type from an entire null object column" )
146149 raise exceptions .UnsupportedType (f"Unsupported Pyarrow type: { dtype } " ) # pragma: no cover
@@ -167,7 +170,7 @@ def pyarrow2pandas_extension( # pylint: disable=too-many-branches,too-many-retu
167170
168171def pyarrow2sqlalchemy ( # pylint: disable=too-many-branches,too-many-return-statements
169172 dtype : pa .DataType , db_type : str
170- ) -> VisitableType :
173+ ) -> Optional [ VisitableType ] :
171174 """Pyarrow to Athena data types conversion."""
172175 if pa .types .is_int8 (dtype ):
173176 return sqlalchemy .types .SmallInteger
@@ -214,7 +217,7 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta
214217 if pa .types .is_dictionary (dtype ):
215218 return pyarrow2sqlalchemy (dtype = dtype .value_type , db_type = db_type )
216219 if dtype == pa .null (): # pragma: no cover
217- raise exceptions . UndetectedType ( "We can not infer the data type from an entire null object column" )
220+ return None
218221 raise exceptions .UnsupportedType (f"Unsupported Pyarrow type: { dtype } " ) # pragma: no cover
219222
220223
@@ -243,12 +246,23 @@ def pyarrow_types_from_pandas(
243246 else :
244247 cols .append (name )
245248
246- # Filling cols_dtypes and indexes
249+ # Filling cols_dtypes
250+ for col in cols :
251+ _logger .debug ("Inferring PyArrow type from column: %s" , col )
252+ try :
253+ schema : pa .Schema = pa .Schema .from_pandas (df = df [[col ]], preserve_index = False )
254+ except pa .ArrowInvalid as ex : # pragma: no cover
255+ cols_dtypes [col ] = process_not_inferred_dtype (ex )
256+ else :
257+ cols_dtypes [col ] = schema .field (col ).type
258+
259+ # Filling indexes
247260 indexes : List [str ] = []
248- for field in pa .Schema .from_pandas (df = df [cols ], preserve_index = index ):
249- name = str (field .name )
250- cols_dtypes [name ] = field .type
251- if (name not in df .columns ) and (index is True ):
261+ if index is True :
262+ for field in pa .Schema .from_pandas (df = df [[]], preserve_index = True ):
263+ name = str (field .name )
264+ _logger .debug ("Inferring PyArrow type from index: %s" , name )
265+ cols_dtypes [name ] = field .type
252266 indexes .append (name )
253267
254268 # Merging Index
@@ -261,6 +275,39 @@ def pyarrow_types_from_pandas(
261275 return columns_types
262276
263277
def process_not_inferred_dtype(ex: pa.ArrowInvalid) -> pa.DataType:
    """Infer data type from PyArrow inference exception.

    The exception message is parsed for the Python type name that PyArrow
    reported it could not recognize, and that name is mapped to an
    explicit PyArrow data type.

    Parameters
    ----------
    ex : pa.ArrowInvalid
        Exception raised by PyArrow while inferring a data type.

    Returns
    -------
    pa.DataType
        ``pa.string()`` when the unrecognized Python type is ``UUID``.

    Raises
    ------
    pa.ArrowInvalid
        The same exception received (``ex``) when its message does not match
        the expected format or the reported type has no known mapping.
    """
    ex_str = str(ex)
    _logger.debug("PyArrow was not able to infer data type:\n %s", ex_str)
    match: Optional[Match] = re.search(
        pattern="Could not convert (.*) with type (.*): did not recognize "
        "Python value type when inferring an Arrow data type",
        string=ex_str,
    )
    if match is None:
        raise ex  # pragma: no cover
    # The pattern declares exactly two capture groups, so a successful match
    # always yields a 2-tuple: (value repr, type name). No further None/length
    # checks are needed.
    groups: Sequence[str] = match.groups()
    _logger.debug("groups: %s", groups)
    type_str: str = groups[1]
    if type_str == "UUID":
        return pa.string()
    # No mapping known for this type: surface the original PyArrow error.
    raise ex  # pragma: no cover
299+
300+
def process_not_inferred_array(ex: pa.ArrowInvalid, values: Any) -> pa.Array:
    """Build a `pyarrow.array` from values whose type PyArrow could not infer.

    The target data type is resolved from the exception through
    ``process_not_inferred_dtype``. When it resolves to ``pa.string()``,
    every value is converted with ``str()`` before the array is built.

    Parameters
    ----------
    ex : pa.ArrowInvalid
        Exception raised by PyArrow while inferring a data type.
    values : Any
        Iterable of the values to place in the array.

    Returns
    -------
    pa.Array
        Array of the resolved data type.

    Raises
    ------
    pa.ArrowInvalid
        The same exception received (``ex``) when the resolved data type
        is not handled here.
    """
    resolved_type = process_not_inferred_dtype(ex=ex)
    if resolved_type == pa.string():
        return pa.array(obj=list(map(str, values)), type=resolved_type, safe=True)
    raise ex  # pragma: no cover
309+
310+
264311def athena_types_from_pandas (
265312 df : pd .DataFrame , index : bool , dtype : Optional [Dict [str , str ]] = None , index_left : bool = False
266313) -> Dict [str , str ]:
0 commit comments