import multiprocessing as mp
import logging
from math import floor
+import copy
+import csv

import pandas
import pyarrow
from pyarrow import parquet

-from awswrangler.exceptions import UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object
+from awswrangler.exceptions import UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object, LineTerminatorNotFound
from awswrangler.utils import calculate_bounders
from awswrangler import s3

@@ -41,7 +43,7 @@ def read_csv(
            sep=",",
            lineterminator="\n",
            quotechar='"',
-            quoting=0,
+            quoting=csv.QUOTE_MINIMAL,
            escapechar=None,
            parse_dates=False,
            infer_datetime_format=False,
@@ -119,7 +121,7 @@ def _read_csv_iterator(
            sep=",",
            lineterminator="\n",
            quotechar='"',
-            quoting=0,
+            quoting=csv.QUOTE_MINIMAL,
            escapechar=None,
            parse_dates=False,
            infer_datetime_format=False,
@@ -177,38 +179,38 @@ def _read_csv_iterator(
        bounders_len = len(bounders)
        count = 0
        forgotten_bytes = 0
-        cols_names = None
        for ini, end in bounders:
            count += 1
+
            ini -= forgotten_bytes
            end -= 1  # Range is inclusive, contrary to Python's List
            bytes_range = "bytes={}-{}".format(ini, end)
            logger.debug(f"bytes_range: {bytes_range}")
            body = client_s3.get_object(Bucket=bucket_name, Key=key_path, Range=bytes_range)["Body"]\
                .read()\
-                .decode(encoding, errors="ignore")
+                .decode("utf-8")
            chunk_size = len(body)
            logger.debug(f"chunk_size: {chunk_size}")
-            if body[0] == lineterminator:
-                first_char = 1
-            else:
-                first_char = 0
-            if (count == 1) and (count == bounders_len):
-                last_break_line_idx = chunk_size
-            elif count == 1:  # first chunk
-                last_break_line_idx = body.rindex(lineterminator)
-                forgotten_bytes = chunk_size - last_break_line_idx
+
+            if count == 1:  # first chunk
+                last_char = Pandas._find_terminator(
+                    body=body,
+                    quoting=quoting,
+                    quotechar=quotechar,
+                    lineterminator=lineterminator)
+                forgotten_bytes = len(body[last_char:].encode("utf-8"))
            elif count == bounders_len:  # Last chunk
-                header = None
-                names = cols_names
-                last_break_line_idx = chunk_size
+                last_char = chunk_size
            else:
-                header = None
-                names = cols_names
-                last_break_line_idx = body.rindex(lineterminator)
-                forgotten_bytes = chunk_size - last_break_line_idx
+                last_char = Pandas._find_terminator(
+                    body=body,
+                    quoting=quoting,
+                    quotechar=quotechar,
+                    lineterminator=lineterminator)
+                forgotten_bytes = len(body[last_char:].encode("utf-8"))
+
            df = pandas.read_csv(
-                StringIO(body[first_char:last_break_line_idx]),
+                StringIO(body[:last_char]),
                header=header,
                names=names,
                sep=sep,
@@ -223,7 +225,64 @@ def _read_csv_iterator(
            )
            yield df
            if count == 1:  # first chunk
-                cols_names = df.columns
+                names = df.columns
+                header = None
+
+    @staticmethod
+    def _find_terminator(body, quoting, quotechar, lineterminator):
+        """
+        Find the index of a suspected line terminator, scanning from the end to the start
+        :param body: String
+        :param quoting: Same as pandas.read_csv()
+        :param quotechar: Same as pandas.read_csv()
+        :param lineterminator: Same as pandas.read_csv()
+        :return: The index of the suspected line terminator
+        """
+        try:
+            if quoting == csv.QUOTE_ALL:
+                index = body.rindex(lineterminator)
+                while True:
+                    i = 0
+                    while True:
+                        i += 1
+                        if index + i <= len(body) - 1:
+                            c = body[index + i]
+                            if c == ",":
+                                pass
+                            elif c == quotechar:
+                                right = True
+                                break
+                            else:
+                                right = False
+                                break
+                        else:
+                            right = True
+                            break
+                    i = 0
+                    while True:
+                        i += 1
+                        if index - i >= 0:
+                            c = body[index - i]
+                            if c == ",":
+                                pass
+                            elif c == quotechar:
+                                left = True
+                                break
+                            else:
+                                left = False
+                                break
+                        else:
+                            left = True
+                            break
+
+                    if right and left:
+                        break
+                    index = body[:index].rindex(lineterminator)
+            else:
+                index = body.rindex(lineterminator)
+        except ValueError:
+            raise LineTerminatorNotFound()
+        return index

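Illustrative sketch (not part of the diff, assuming the helper above): with quoting=csv.QUOTE_ALL, _find_terminator is expected to walk backwards past any lineterminator that falls inside a still-open quoted field and settle on one that sits between two quotechars (or at a chunk edge):

body = '"1","foo"\n"2","ba\nr and more'   # chunk cut in the middle of a quoted field
idx = Pandas._find_terminator(body=body,
                              quoting=csv.QUOTE_ALL,
                              quotechar='"',
                              lineterminator="\n")
assert idx == 9                    # the "\n" after "foo", not the one inside the open field
assert body[:idx] == '"1","foo"'
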
    @staticmethod
    def _read_csv_once(
@@ -293,7 +352,7 @@ def read_sql_athena(self,
        Executes any SQL query on AWS Athena and returns a Dataframe of the result.
        P.S. If max_result_size is passed, then an iterator of Dataframes is returned.
        :param sql: SQL Query
-        :param database: Glue/Athena Databease
+        :param database: Glue/Athena Database
        :param s3_output: AWS S3 path
        :param max_result_size: Max number of bytes on each request to S3
        :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
@@ -318,8 +377,14 @@ def read_sql_athena(self,
            message_error = f"Query error: {reason}"
            raise AthenaQueryError(message_error)
        else:
+            dtype, parse_dates = self._session.athena.get_query_dtype(
+                query_execution_id=query_execution_id)
            path = f"{s3_output}{query_execution_id}.csv"
-            ret = self.read_csv(path=path, max_result_size=max_result_size)
+            ret = self.read_csv(path=path,
+                                dtype=dtype,
+                                parse_dates=parse_dates,
+                                quoting=csv.QUOTE_ALL,
+                                max_result_size=max_result_size)
        return ret

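Hedged sketch of what the new call chain amounts to (illustrative only; from its use above, get_query_dtype appears to return a pandas dtype mapping plus the timestamp columns to parse): the Athena result file is read back with all fields treated as quoted and with the types Athena reported, rather than letting pandas infer them:

from io import StringIO
import csv
import pandas

athena_csv = '"id","created"\n"1","2019-01-01 00:00:00.000"\n'   # Athena-style result, fields quoted
df = pandas.read_csv(StringIO(athena_csv),
                     quoting=csv.QUOTE_ALL,
                     dtype={"id": "Int64"},
                     parse_dates=["created"])
print(df.dtypes)   # id -> Int64 (nullable), created -> datetime64[ns]
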
    def to_csv(
@@ -623,11 +688,18 @@ def write_csv_dataframe(dataframe, path, preserve_index, fs):
            f.write(csv_buffer)

    @staticmethod
-    def write_parquet_dataframe(dataframe,
-                                path,
-                                preserve_index,
-                                fs,
-                                cast_columns=None):
+    def write_parquet_dataframe(dataframe, path, preserve_index, fs,
+                                cast_columns):
+        if not cast_columns:
+            cast_columns = {}
+        casted_in_pandas = []
+        dtypes = copy.deepcopy(dataframe.dtypes.to_dict())
+        for name, dtype in dtypes.items():
+            if str(dtype) == "Int64":
+                dataframe[name] = dataframe[name].astype("float64")
+                casted_in_pandas.append(name)
+                cast_columns[name] = "int64"
+                logger.debug(f"Casting column {name} Int64 to float64")
        table = pyarrow.Table.from_pandas(df=dataframe,
                                          preserve_index=preserve_index,
                                          safe=False)
@@ -636,13 +708,15 @@ def write_parquet_dataframe(dataframe,
            col_index = table.column_names.index(col_name)
            table = table.set_column(col_index,
                                     table.column(col_name).cast(dtype))
-            logger.debug(f"{col_name} - {col_index} - {dtype}")
-        logger.debug(f"table.schema:\n{table.schema}")
+            logger.debug(
+                f"Casting column {col_name} ({col_index}) to {dtype}")
        with fs.open(path, "wb") as f:
            parquet.write_table(table,
                                f,
                                coerce_timestamps="ms",
                                flavor="spark")
+        for col in casted_in_pandas:
+            dataframe[col] = dataframe[col].astype("Int64")

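Minimal sketch (not part of the diff) of the Int64 round trip above: nullable integer columns are temporarily downgraded to float64 before the frame reaches pyarrow, the int64 target is recorded in cast_columns for the Arrow-side cast, and the caller's DataFrame gets its nullable dtype back afterwards:

import pandas

df = pandas.DataFrame({"id": pandas.array([1, 2, None], dtype="Int64")})
cast_columns = {}
df["id"] = df["id"].astype("float64")   # 1.0, 2.0, NaN -- what pyarrow actually receives
cast_columns["id"] = "int64"            # write_parquet_dataframe casts the Arrow column back to int64
# ... pyarrow.Table.from_pandas / parquet.write_table would run here ...
df["id"] = df["id"].astype("Int64")     # restore the nullable dtype on the caller's frame
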
    def to_redshift(
            self,