@@ -76,6 +76,11 @@ def external_schema(cloudformation_outputs, parameters, glue_database):
     yield "aws_data_wrangler_external"
 
 
+@pytest.fixture(scope="module")
+def kms_key_id(cloudformation_outputs):
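+    # Assumption: KmsKeyArn has the usual "arn:aws:kms:<region>:<account>:key/<key-id>" shape,
+    # so splitting on the first "/" keeps only the trailing key id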
+    yield cloudformation_outputs["KmsKeyArn"].split("/", 1)[1]
+
+
 @pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
 def test_sql(parameters, db_type):
     df = get_df()
@@ -386,3 +391,72 @@ def test_redshift_category(bucket, parameters):
     for df2 in dfs:
         ensure_data_types_category(df2)
     wr.s3.delete_objects(path=path)
+
+
+def test_redshift_unload_extras(bucket, parameters, kms_key_id):
+    table = "test_redshift_unload_extras"
+    schema = parameters["redshift"]["schema"]
+    path = f"s3://{bucket}/{table}/"
+    wr.s3.delete_objects(path=path)
+    engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
+    df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})
+    wr.db.to_sql(df=df, con=engine, name=table, schema=schema, if_exists="replace", index=False)
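+    # unload_redshift_to_files runs a Redshift UNLOAD to Parquet files on S3,
+    # here partitioned by "name" and encrypted with the KMS key from the fixture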
+    paths = wr.db.unload_redshift_to_files(
+        sql=f"SELECT * FROM {schema}.{table}",
+        path=path,
+        con=engine,
+        iam_role=parameters["redshift"]["role"],
+        region=wr.s3.get_bucket_region(bucket),
+        max_file_size=5.0,
+        kms_key_id=kms_key_id,
+        partition_cols=["name"],
+    )
+    wr.s3.wait_objects_exist(paths=paths)
+    df = wr.s3.read_parquet(path=path, dataset=True)
+    assert len(df.index) == 2
+    assert len(df.columns) == 2
+    wr.s3.delete_objects(path=path)
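+    # unload_redshift performs the same UNLOAD but reads the result straight back
+    # into a DataFrame and, with keep_files=False, removes the staged S3 files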
+    df = wr.db.unload_redshift(
+        sql=f"SELECT * FROM {schema}.{table}",
+        con=engine,
+        iam_role=parameters["redshift"]["role"],
+        path=path,
+        keep_files=False,
+        region=wr.s3.get_bucket_region(bucket),
+        max_file_size=5.0,
+        kms_key_id=kms_key_id,
+    )
+    assert len(df.index) == 2
+    assert len(df.columns) == 2
+    wr.s3.delete_objects(path=path)
+
+
+@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
+def test_to_sql_cast(parameters, db_type):
+    table = "test_to_sql_cast"
+    schema = parameters[db_type]["schema"]
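+    # "".join(str(i)[-1] for i in range(1_024)) builds a deterministic
+    # 1,024-character string (the last digit of each i), matching VARCHAR(1024) below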
+    df = pd.DataFrame(
+        {
+            "col": [
+                "".join([str(i)[-1] for i in range(1_024)]),
+                "".join([str(i)[-1] for i in range(1_024)]),
+                "".join([str(i)[-1] for i in range(1_024)]),
+            ]
+        },
+        dtype="string",
+    )
+    engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}")
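+    # Assumption: the dtype argument is forwarded to pandas.DataFrame.to_sql,
+    # so the column is created as VARCHAR(1024) rather than the default TEXT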
+    wr.db.to_sql(
+        df=df,
+        con=engine,
+        name=table,
+        schema=schema,
+        if_exists="replace",
+        index=False,
+        index_label=None,
+        chunksize=None,
+        method=None,
+        dtype={"col": sqlalchemy.types.VARCHAR(length=1_024)},
+    )
+    df2 = wr.db.read_sql_query(sql=f"SELECT * FROM {schema}.{table}", con=engine)
+    assert df.equals(df2)