1414import sqlalchemy_redshift .dialect # type: ignore
1515from sqlalchemy .sql .visitors import VisitableType # type: ignore
1616
17- from awswrangler import exceptions
17+ from awswrangler import _utils , exceptions
1818
1919_logger : logging .Logger = logging .getLogger (__name__ )
2020
@@ -44,14 +44,14 @@ def athena2pyarrow(dtype: str) -> pa.DataType: # pylint: disable=too-many-retur
4444 return pa .date32 ()
4545 if dtype in ("binary" or "varbinary" ):
4646 return pa .binary ()
47- if dtype .startswith ("decimal" ):
47+ if dtype .startswith ("decimal" ) is True :
4848 precision , scale = dtype .replace ("decimal(" , "" ).replace (")" , "" ).split (sep = "," )
4949 return pa .decimal128 (precision = int (precision ), scale = int (scale ))
50- if dtype .startswith ("array" ):
51- return pa .large_list ( athena2pyarrow (dtype = dtype [6 :- 1 ]))
52- if dtype .startswith ("struct" ):
50+ if dtype .startswith ("array" ) is True :
51+ return pa .list_ ( value_type = athena2pyarrow (dtype = dtype [6 :- 1 ]), list_size = - 1 )
52+ if dtype .startswith ("struct" ) is True :
5353 return pa .struct ([(f .split (":" , 1 )[0 ], athena2pyarrow (f .split (":" , 1 )[1 ])) for f in dtype [7 :- 1 ].split ("," )])
54- if dtype .startswith ("map" ): # pragma: no cover
54+ if dtype .startswith ("map" ) is True : # pragma: no cover
5555 return pa .map_ (athena2pyarrow (dtype [4 :- 1 ].split ("," , 1 )[0 ]), athena2pyarrow (dtype [4 :- 1 ].split ("," , 1 )[1 ]))
5656 raise exceptions .UnsupportedType (f"Unsupported Athena type: { dtype } " ) # pragma: no cover
5757
@@ -396,7 +396,7 @@ def pyarrow_schema_from_pandas(
396396 )
397397 for k , v in casts .items ():
398398 if (k in df .columns ) and (k not in ignore ):
399- columns_types [k ] = athena2pyarrow (v )
399+ columns_types [k ] = athena2pyarrow (dtype = v )
400400 columns_types = {k : v for k , v in columns_types .items () if v is not None }
401401 _logger .debug ("columns_types: %s" , columns_types )
402402 return pa .schema (fields = columns_types )
@@ -417,47 +417,67 @@ def athena_types_from_pyarrow_schema(
417417
418418def cast_pandas_with_athena_types (df : pd .DataFrame , dtype : Dict [str , str ]) -> pd .DataFrame :
419419 """Cast columns in a Pandas DataFrame."""
420+ mutable_ensured : bool = False
420421 for col , athena_type in dtype .items ():
421422 if (
422423 (col in df .columns )
423- and (not athena_type .startswith ("array" ))
424- and (not athena_type .startswith ("struct" ))
425- and (not athena_type .startswith ("map" ))
424+ and (athena_type .startswith ("array" ) is False )
425+ and (athena_type .startswith ("struct" ) is False )
426+ and (athena_type .startswith ("map" ) is False )
426427 ):
427- pandas_type : str = athena2pandas (dtype = athena_type )
428- if pandas_type == "datetime64" :
429- df [col ] = pd .to_datetime (df [col ])
430- elif pandas_type == "date" :
431- df [col ] = pd .to_datetime (df [col ]).dt .date .replace (to_replace = {pd .NaT : None })
432- elif pandas_type == "bytes" :
433- df [col ] = df [col ].astype ("string" ).str .encode (encoding = "utf-8" ).replace (to_replace = {pd .NA : None })
434- elif pandas_type == "decimal" :
435- df [col ] = (
436- df [col ]
437- .astype ("string" )
438- .apply (lambda x : Decimal (str (x )) if str (x ) not in ("" , "none" , "None" , " " , "<NA>" ) else None )
439- )
440- elif pandas_type == "string" :
441- curr_type : str = str (df [col ].dtypes )
442- if curr_type .lower ().startswith ("int" ) is True :
443- df [col ] = df [col ].astype (str ).astype ("string" )
444- elif curr_type .startswith ("float" ) is True :
445- df [col ] = df [col ].astype (str ).astype ("string" )
446- elif curr_type in ("object" , "category" ):
447- df [col ] = df [col ].astype (str ).astype ("string" )
448- else :
449- df [col ] = df [col ].astype ("string" )
450- else :
451- try :
452- df [col ] = df [col ].astype (pandas_type )
453- except TypeError as ex :
454- if "object cannot be converted to an IntegerDtype" not in str (ex ):
455- raise ex # pragma: no cover
456- df [col ] = (
457- df [col ]
458- .apply (lambda x : int (x ) if str (x ) not in ("" , "none" , "None" , " " , "<NA>" ) else None )
459- .astype (pandas_type )
460- )
428+ desired_type : str = athena2pandas (dtype = athena_type )
429+ current_type : str = _normalize_pandas_dtype_name (dtype = str (df [col ].dtypes ))
430+ if desired_type != current_type : # Needs conversion
431+ _logger .debug ("current_type: %s -> desired_type: %s" , current_type , desired_type )
432+ if mutable_ensured is False :
433+ df = _utils .ensure_df_is_mutable (df = df )
434+ mutable_ensured = True
435+ _cast_pandas_column (df = df , col = col , current_type = current_type , desired_type = desired_type )
436+
437+ return df
438+
439+
440+ def _normalize_pandas_dtype_name (dtype : str ) -> str :
441+ if dtype .startswith ("datetime64" ) is True :
442+ return "datetime64"
443+ if dtype .startswith ("decimal" ) is True :
444+ return "decimal" # pragma: no cover
445+ return dtype
446+
447+
448+ def _cast_pandas_column (df : pd .DataFrame , col : str , current_type : str , desired_type : str ) -> pd .DataFrame :
449+ if desired_type == "datetime64" :
450+ df [col ] = pd .to_datetime (df [col ])
451+ elif desired_type == "date" :
452+ df [col ] = pd .to_datetime (df [col ]).dt .date .replace (to_replace = {pd .NaT : None })
453+ elif desired_type == "bytes" :
454+ df [col ] = df [col ].astype ("string" ).str .encode (encoding = "utf-8" ).replace (to_replace = {pd .NA : None })
455+ elif desired_type == "decimal" :
456+ df [col ] = (
457+ df [col ]
458+ .astype ("string" )
459+ .apply (lambda x : Decimal (str (x )) if str (x ) not in ("" , "none" , "None" , " " , "<NA>" ) else None )
460+ )
461+ elif desired_type == "string" :
462+ if current_type .lower ().startswith ("int" ) is True :
463+ df [col ] = df [col ].astype (str ).astype ("string" )
464+ elif current_type .startswith ("float" ) is True :
465+ df [col ] = df [col ].astype (str ).astype ("string" )
466+ elif current_type in ("object" , "category" ):
467+ df [col ] = df [col ].astype (str ).astype ("string" )
468+ else :
469+ df [col ] = df [col ].astype ("string" )
470+ else :
471+ try :
472+ df [col ] = df [col ].astype (desired_type )
473+ except TypeError as ex :
474+ if "object cannot be converted to an IntegerDtype" not in str (ex ):
475+ raise ex # pragma: no cover
476+ df [col ] = (
477+ df [col ]
478+ .apply (lambda x : int (x ) if str (x ) not in ("" , "none" , "None" , " " , "<NA>" ) else None )
479+ .astype (desired_type )
480+ )
461481 return df
462482
463483
0 commit comments