11from typing import TYPE_CHECKING , Union , List , Dict , Tuple , Any
2- from logging import getLogger , Logger
2+ from logging import getLogger , Logger , INFO
33import json
44import warnings
55
66import pg8000 # type: ignore
7+ from pg8000 import ProgrammingError # type: ignore
78import pymysql # type: ignore
89import pandas as pd # type: ignore
910from boto3 import client # type: ignore
11+ import tenacity # type: ignore
1012
1113from awswrangler import data_types
1214from awswrangler .exceptions import InvalidEngine , InvalidDataframeType , AuroraLoadError
@@ -134,7 +136,7 @@ def load_table(dataframe: pd.DataFrame,
134136 schema_name : str ,
135137 table_name : str ,
136138 connection : Any ,
137- num_files ,
139+ num_files : int ,
138140 mode : str = "append" ,
139141 preserve_index : bool = False ,
140142 engine : str = "mysql" ,
@@ -156,6 +158,54 @@ def load_table(dataframe: pd.DataFrame,
156158 :param region: AWS S3 bucket region (Required only for postgres engine)
157159 :return: None
158160 """
161+ if "postgres" in engine .lower ():
162+ Aurora .load_table_postgres (dataframe = dataframe ,
163+ dataframe_type = dataframe_type ,
164+ load_paths = load_paths ,
165+ schema_name = schema_name ,
166+ table_name = table_name ,
167+ connection = connection ,
168+ mode = mode ,
169+ preserve_index = preserve_index ,
170+ region = region )
171+ elif "mysql" in engine .lower ():
172+ Aurora .load_table_mysql (dataframe = dataframe ,
173+ dataframe_type = dataframe_type ,
174+ manifest_path = load_paths [0 ],
175+ schema_name = schema_name ,
176+ table_name = table_name ,
177+ connection = connection ,
178+ mode = mode ,
179+ preserve_index = preserve_index ,
180+ num_files = num_files )
181+ else :
182+ raise InvalidEngine (f"{ engine } is not a valid engine. Please use 'mysql' or 'postgres'!" )
183+
184+ @staticmethod
185+ def load_table_postgres (dataframe : pd .DataFrame ,
186+ dataframe_type : str ,
187+ load_paths : List [str ],
188+ schema_name : str ,
189+ table_name : str ,
190+ connection : Any ,
191+ mode : str = "append" ,
192+ preserve_index : bool = False ,
193+ region : str = "us-east-1" ):
194+ """
195+ Load text/CSV files into an Aurora table, one S3 path at a time.
196+ Creates the table if necessary.
197+
198+ :param dataframe: Pandas or Spark Dataframe
199+ :param dataframe_type: "pandas" or "spark"
200+ :param load_paths: S3 paths to be loaded (E.g. S3://...)
201+ :param schema_name: Aurora schema
202+ :param table_name: Aurora table name
203+ :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
204+ :param mode: append or overwrite
205+ :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
206+ :param region: AWS S3 bucket region (Required only for postgres engine)
207+ :return: None
208+ """
159209 with connection .cursor () as cursor :
160210 if mode == "overwrite" :
161211 Aurora ._create_table (cursor = cursor ,
@@ -164,30 +214,94 @@ def load_table(dataframe: pd.DataFrame,
164214 schema_name = schema_name ,
165215 table_name = table_name ,
166216 preserve_index = preserve_index ,
167- engine = engine )
168- for path in load_paths :
169- sql = Aurora ._get_load_sql (path = path ,
170- schema_name = schema_name ,
171- table_name = table_name ,
172- engine = engine ,
173- region = region )
174- logger .debug (sql )
217+ engine = "postgres" )
218+ connection .commit ()
219+ logger .debug ("CREATE TABLE committed." )
220+ for path in load_paths :
221+ Aurora ._load_object_postgres_with_retry (connection = connection ,
222+ schema_name = schema_name ,
223+ table_name = table_name ,
224+ path = path ,
225+ region = region )
226+
@staticmethod
@tenacity.retry(retry=tenacity.retry_if_exception_type(exception_types=ProgrammingError),
                wait=tenacity.wait_random_exponential(multiplier=0.5),
                stop=tenacity.stop_after_attempt(max_attempt_number=5),
                reraise=True,
                after=tenacity.after_log(logger, INFO))
def _load_object_postgres_with_retry(connection: Any, schema_name: str, table_name: str, path: str,
                                     region: str) -> None:
    """
    Load a single S3 object into an Aurora PostgreSQL table and commit it.

    The call is retried (up to 5 attempts, random exponential backoff) whenever
    a ProgrammingError is raised; the last error is re-raised once the attempts
    are exhausted. Each attempt is logged at INFO level after it finishes.

    :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
    :param schema_name: Aurora schema
    :param table_name: Aurora table name
    :param path: S3 path of the object to be loaded (E.g. S3://...)
    :param region: AWS S3 bucket region
    :return: None
    :raises ProgrammingError: if the load statement still fails on the last attempt
    """
    with connection.cursor() as cursor:
        sql = Aurora._get_load_sql(path=path,
                                   schema_name=schema_name,
                                   table_name=table_name,
                                   engine="postgres",
                                   region=region)
        logger.debug(sql)
        try:
            cursor.execute(sql)
        except ProgrammingError:
            # PostgreSQL aborts the open transaction on any failed statement.
            # Roll back unconditionally so the next retry (or the caller, once
            # the attempts are exhausted) gets a usable connection and the
            # real error instead of a secondary "transaction aborted" failure.
            connection.rollback()
            raise
        connection.commit()
        logger.debug(f"Load committed for: {path}.")
176250
177- connection .commit ()
178- logger .debug ("Load committed." )
@staticmethod
def load_table_mysql(dataframe: pd.DataFrame,
                     dataframe_type: str,
                     manifest_path: str,
                     schema_name: str,
                     table_name: str,
                     connection: Any,
                     num_files: int,
                     mode: str = "append",
                     preserve_index: bool = False):
    """
    Load text/CSV files into an Aurora MySQL table using a manifest file.
    Creates the table first when mode is "overwrite", then verifies the load
    against mysql.aurora_s3_load_history.

    :param dataframe: Pandas or Spark Dataframe
    :param dataframe_type: "pandas" or "spark"
    :param manifest_path: S3 manifest path to be loaded (E.g. S3://...)
    :param schema_name: Aurora schema
    :param table_name: Aurora table name
    :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
    :param num_files: Number of files to be loaded
    :param mode: append or overwrite
    :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
    :return: None
    :raises AuroraLoadError: if the load history does not hold num_files + 1 entries for the manifest
    """
    with connection.cursor() as cursor:
        if mode == "overwrite":
            Aurora._create_table(cursor=cursor,
                                 dataframe=dataframe,
                                 dataframe_type=dataframe_type,
                                 schema_name=schema_name,
                                 table_name=table_name,
                                 preserve_index=preserve_index,
                                 engine="mysql")
        sql = Aurora._get_load_sql(path=manifest_path,
                                   schema_name=schema_name,
                                   table_name=table_name,
                                   engine="mysql")
        logger.debug(sql)
        cursor.execute(sql)
        logger.debug(f"Load done for: {manifest_path}")
    connection.commit()
    logger.debug("Load committed.")

    with connection.cursor() as cursor:
        # The manifest path is bound as a query parameter rather than
        # interpolated, so a quote in the path can neither break nor alter the
        # statement. "%s" is the placeholder style used by MySQL drivers such
        # as pymysql.
        sql = ("-- AWS DATA WRANGLER\n"
               "SELECT COUNT(*) as num_files_loaded FROM mysql.aurora_s3_load_history "
               "WHERE load_prefix = %s")
        logger.debug(sql)
        cursor.execute(sql, (manifest_path,))
        num_files_loaded = cursor.fetchall()[0][0]
        # One entry more than the data files is expected; presumably the
        # manifest itself is recorded in the history too -- TODO confirm.
        if num_files_loaded != (num_files + 1):
            raise AuroraLoadError(
                f"Missing files to load. {num_files_loaded} files counted. {num_files + 1} expected.")
191305
192306 @staticmethod
193307 def _parse_path (path ):
0 commit comments