1+ from typing import List , Tuple , Dict , Callable
12import logging
23from datetime import datetime , date
34
4- import pyarrow
5+ import pyarrow as pa # type: ignore
6+ import pandas as pd # type: ignore
57
6- from awswrangler .exceptions import UnsupportedType , UndetectedType
8+ from awswrangler .exceptions import UnsupportedType , UndetectedType # type: ignore
79
810logger = logging .getLogger (__name__ )
911
1012
11- def athena2pandas (dtype ) :
13+ def athena2pandas (dtype : str ) -> str :
1214 dtype = dtype .lower ()
1315 if dtype in ["int" , "integer" , "bigint" , "smallint" , "tinyint" ]:
1416 return "Int64"
@@ -28,7 +30,7 @@ def athena2pandas(dtype):
2830 raise UnsupportedType (f"Unsupported Athena type: { dtype } " )
2931
3032
31- def athena2pyarrow (dtype ) :
33+ def athena2pyarrow (dtype : str ) -> str :
3234 dtype = dtype .lower ()
3335 if dtype == "tinyint" :
3436 return "int8"
@@ -54,7 +56,7 @@ def athena2pyarrow(dtype):
5456 raise UnsupportedType (f"Unsupported Athena type: { dtype } " )
5557
5658
57- def athena2python (dtype ) :
59+ def athena2python (dtype : str ) -> type :
5860 dtype = dtype .lower ()
5961 if dtype in ["int" , "integer" , "bigint" , "smallint" , "tinyint" ]:
6062 return int
@@ -72,7 +74,7 @@ def athena2python(dtype):
7274 raise UnsupportedType (f"Unsupported Athena type: { dtype } " )
7375
7476
75- def athena2redshift (dtype ) :
77+ def athena2redshift (dtype : str ) -> str :
7678 dtype = dtype .lower ()
7779 if dtype == "smallint" :
7880 return "SMALLINT"
@@ -96,7 +98,7 @@ def athena2redshift(dtype):
9698 raise UnsupportedType (f"Unsupported Athena type: { dtype } " )
9799
98100
99- def pandas2athena (dtype ) :
101+ def pandas2athena (dtype : str ) -> str :
100102 dtype = dtype .lower ()
101103 if dtype == "int32" :
102104 return "int"
@@ -116,7 +118,7 @@ def pandas2athena(dtype):
116118 raise UnsupportedType (f"Unsupported Pandas type: { dtype } " )
117119
118120
119- def pandas2redshift (dtype ) :
121+ def pandas2redshift (dtype : str ) -> str :
120122 dtype = dtype .lower ()
121123 if dtype == "int32" :
122124 return "INTEGER"
@@ -136,7 +138,7 @@ def pandas2redshift(dtype):
136138 raise UnsupportedType ("Unsupported Pandas type: " + dtype )
137139
138140
139- def pyarrow2athena (dtype ) :
141+ def pyarrow2athena (dtype : pa . types ) -> str :
140142 dtype_str = str (dtype ).lower ()
141143 if dtype_str == "int8" :
142144 return "tinyint"
@@ -167,7 +169,7 @@ def pyarrow2athena(dtype):
167169 raise UnsupportedType (f"Unsupported Pyarrow type: { dtype } " )
168170
169171
170- def pyarrow2redshift (dtype ) :
172+ def pyarrow2redshift (dtype : pa . types ) -> str :
171173 dtype_str = str (dtype ).lower ()
172174 if dtype_str == "int16" :
173175 return "SMALLINT"
@@ -191,25 +193,25 @@ def pyarrow2redshift(dtype):
191193 raise UnsupportedType (f"Unsupported Pyarrow type: { dtype } " )
192194
193195
def python2athena(python_type: type) -> str:
    """
    Convert a Python type into the equivalent Athena data type name

    The lookup is made on the string representation of the received type
    (e.g. "<class 'int'>").

    :param python_type: Python type (e.g. int, float, bool, str, datetime, date)
    :return: Athena data type name (e.g. "bigint")
    :raises UnsupportedType: if there is no Athena equivalent mapped
    """
    python_type_str: str = str(python_type)
    if python_type_str == "<class 'int'>":
        return "bigint"
    elif python_type_str == "<class 'float'>":
        return "double"
    # Fixed typo: it was "<class 'boll'>", which never matched the bool type
    elif python_type_str == "<class 'bool'>":
        return "boolean"
    elif python_type_str == "<class 'str'>":
        return "string"
    elif python_type_str == "<class 'datetime.datetime'>":
        return "timestamp"
    elif python_type_str == "<class 'datetime.date'>":
        return "date"
    else:
        raise UnsupportedType(f"Unsupported Python type: {python_type_str}")
210212
211213
212- def redshift2athena (dtype ) :
214+ def redshift2athena (dtype : str ) -> str :
213215 dtype_str = str (dtype )
214216 if dtype_str in ["SMALLINT" , "INT2" ]:
215217 return "smallint"
@@ -233,8 +235,8 @@ def redshift2athena(dtype):
233235 raise UnsupportedType (f"Unsupported Redshift type: { dtype_str } " )
234236
235237
236- def redshift2pyarrow (dtype ) :
237- dtype_str = str (dtype )
238+ def redshift2pyarrow (dtype : str ) -> str :
239+ dtype_str : str = str (dtype )
238240 if dtype_str in ["SMALLINT" , "INT2" ]:
239241 return "int16"
240242 elif dtype_str in ["INTEGER" , "INT" , "INT4" ]:
@@ -257,7 +259,7 @@ def redshift2pyarrow(dtype):
257259 raise UnsupportedType (f"Unsupported Redshift type: { dtype_str } " )
258260
259261
260- def spark2redshift (dtype ) :
262+ def spark2redshift (dtype : str ) -> str :
261263 dtype = dtype .lower ()
262264 if dtype == "smallint" :
263265 return "SMALLINT"
@@ -281,7 +283,7 @@ def spark2redshift(dtype):
281283 raise UnsupportedType ("Unsupported Spark type: " + dtype )
282284
283285
284- def convert_schema (func , schema ) :
286+ def convert_schema (func : Callable , schema : List [ Tuple [ str , str ]]) -> Dict [ str , str ] :
285287 """
286288 Convert schema in the format of {"col name": "bigint", "col2 name": "int"}
287289 applying some data types conversion function (e.g. spark2redshift)
@@ -293,16 +295,16 @@ def convert_schema(func, schema):
293295 return {name : func (dtype ) for name , dtype in schema }
294296
295297
296- def extract_pyarrow_schema_from_pandas (dataframe ,
297- preserve_index ,
298- indexes_position = "right" ):
298+ def extract_pyarrow_schema_from_pandas (dataframe : pd . DataFrame ,
299+ preserve_index : bool ,
300+ indexes_position : str = "right" ) -> List [ Tuple [ str , str ]] :
299301 """
300302 Extract the related Pyarrow schema from any Pandas DataFrame
301303
302304 :param dataframe: Pandas Dataframe
303305 :param preserve_index: True or False
304306 :param indexes_position: "right" or "left"
305- :return: Pyarrow schema (e.g. { "col name": "bigint", "col2 name": "int"})
307+ :return: Pyarrow schema (e.g. [("col name", "bigint"), ("col2 name", "int")])
306308 """
307309 cols = []
308310 cols_dtypes = {}
@@ -319,8 +321,8 @@ def extract_pyarrow_schema_from_pandas(dataframe,
319321
320322 # Filling cols_dtypes and indexes
321323 indexes = []
322- for field in pyarrow .Schema .from_pandas (df = dataframe [cols ],
323- preserve_index = preserve_index ):
324+ for field in pa .Schema .from_pandas (df = dataframe [cols ],
325+ preserve_index = preserve_index ):
324326 name = str (field .name )
325327 dtype = field .type
326328 cols_dtypes [name ] = dtype
0 commit comments