     AthenaQueryError, EmptyS3Object, LineTerminatorNotFound, EmptyDataframe, \
     InvalidSerDe, InvalidCompression
 from awswrangler.utils import calculate_bounders
-from awswrangler import s3
+from awswrangler import s3, glue

 logger = logging.getLogger(__name__)

@@ -56,6 +56,7 @@ def read_csv(
             parse_dates=False,
             infer_datetime_format=False,
             encoding="utf-8",
+            converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -76,6 +77,7 @@ def read_csv(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
         """
         bucket_name, key_path = self._parse_path(path)
@@ -99,7 +101,8 @@ def read_csv(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         else:
             ret = Pandas._read_csv_once(
                 client_s3=client_s3,
@@ -115,7 +118,8 @@ def read_csv(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         return ret

     @staticmethod
@@ -135,6 +139,7 @@ def _read_csv_iterator(
             parse_dates=False,
             infer_datetime_format=False,
             encoding="utf-8",
+            converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -156,6 +161,7 @@ def _read_csv_iterator(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe
         """
         metadata = s3.S3.head_object_with_retry(client=client_s3,
@@ -181,7 +187,8 @@ def _read_csv_iterator(
                 escapechar=escapechar,
                 parse_dates=parse_dates,
                 infer_datetime_format=infer_datetime_format,
-                encoding=encoding)
+                encoding=encoding,
+                converters=converters)
         else:
             bounders = calculate_bounders(num_items=total_size,
                                           max_size=max_result_size)
@@ -234,7 +241,7 @@ def _read_csv_iterator(
                     lineterminator=lineterminator,
                     dtype=dtype,
                     encoding=encoding,
-                )
+                    converters=converters)
                 yield df
                 if count == 1:  # first chunk
                     names = df.columns
@@ -352,6 +359,7 @@ def _read_csv_once(
             parse_dates=False,
             infer_datetime_format=False,
             encoding=None,
+            converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -372,6 +380,7 @@ def _read_csv_once(
         :param parse_dates: Same as pandas.read_csv()
         :param infer_datetime_format: Same as pandas.read_csv()
         :param encoding: Same as pandas.read_csv()
+        :param converters: Same as pandas.read_csv()
         :return: Pandas Dataframe
         """
         buff = BytesIO()
@@ -392,6 +401,7 @@ def _read_csv_once(
             lineterminator=lineterminator,
             dtype=dtype,
             encoding=encoding,
+            converters=converters,
         )
         buff.close()
         return dataframe
@@ -425,12 +435,13 @@ def read_sql_athena(self,
             message_error = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
-            dtype, parse_timestamps, parse_dates = self._session.athena.get_query_dtype(
+            dtype, parse_timestamps, parse_dates, converters = self._session.athena.get_query_dtype(
                 query_execution_id=query_execution_id)
             path = f"{s3_output}{query_execution_id}.csv"
             ret = self.read_csv(path=path,
                                 dtype=dtype,
                                 parse_dates=parse_timestamps,
+                                converters=converters,
                                 quoting=csv.QUOTE_ALL,
                                 max_result_size=max_result_size)
             if max_result_size is None:
@@ -848,18 +859,21 @@ def write_parquet_dataframe(dataframe,
             if str(dtype) == "Int64":
                 dataframe[name] = dataframe[name].astype("float64")
                 casted_in_pandas.append(name)
-                cast_columns[name] = "int64"
+                cast_columns[name] = "bigint"
                 logger.debug(f"Casting column {name} Int64 to float64")
         table = pyarrow.Table.from_pandas(df=dataframe,
                                           preserve_index=preserve_index,
                                           safe=False)
         if cast_columns:
             for col_name, dtype in cast_columns.items():
                 col_index = table.column_names.index(col_name)
-                table = table.set_column(col_index,
-                                         table.column(col_name).cast(dtype))
+                pyarrow_dtype = glue.Glue.type_athena2pyarrow(dtype)
+                table = table.set_column(
+                    col_index,
+                    table.column(col_name).cast(pyarrow_dtype))
                 logger.debug(
-                    f"Casting column {col_name} ({col_index}) to {dtype}")
+                    f"Casting column {col_name} ({col_index}) to {dtype} ({pyarrow_dtype})"
+                )
         with fs.open(path, "wb") as f:
             parquet.write_table(table,
                                 f,
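
For context, a minimal usage sketch of the new pass-through, assuming the Session-based entry point this version of the library exposes; the bucket, column, and converter below are illustrative only. Each entry in converters is forwarded untouched to pandas.read_csv, mapping a column name to a callable applied to every raw cell of that column.

# Hedged sketch: hypothetical path and converter, not part of the patch.
import awswrangler

session = awswrangler.Session()  # credentials resolved from the environment
df = session.pandas.read_csv(
    path="s3://some-bucket/some-key.csv",                # hypothetical object
    converters={"amount": lambda v: v.strip() or None},  # handed straight to pandas.read_csv
)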
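
And a standalone sketch of the write_parquet_dataframe cast path, assuming Athena's "bigint" resolves to pyarrow's int64 (the lookup the patch delegates to glue.Glue.type_athena2pyarrow). It is written against a current pyarrow, which also expects the field to be passed to set_column; the two-argument form in the diff reflects an older API.

# Hedged sketch of the Int64 -> float64 -> int64 round trip; runs standalone.
import pandas
import pyarrow

df = pandas.DataFrame({"id": pandas.array([1, 2, 3], dtype="Int64")})
df["id"] = df["id"].astype("float64")          # pyarrow-friendly intermediate, as in the patch
table = pyarrow.Table.from_pandas(df=df, preserve_index=False, safe=False)

int_type = pyarrow.int64()                     # assumed mapping of Athena "bigint"
col_index = table.column_names.index("id")
table = table.set_column(col_index,
                         pyarrow.field("id", int_type),       # newer pyarrow wants the field here
                         table.column("id").cast(int_type))
print(table.schema)                            # id: int64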