@@ -56,6 +56,26 @@ def parameters(cloudformation_outputs):
5656 yield parameters
5757
5858
@pytest.fixture(scope="module")
def glue_database(cloudformation_outputs):
    """Provide the value stored under ``GlueDatabaseName`` in the stack outputs.

    Module-scoped, so the lookup happens once per test module.
    """
    database_name = cloudformation_outputs["GlueDatabaseName"]
    yield database_name
62+
63+
@pytest.fixture(scope="module")
def external_schema(cloudformation_outputs, parameters, glue_database):
    """Create the external schema used by the Spectrum test and yield its name.

    Runs ``CREATE EXTERNAL SCHEMA IF NOT EXISTS`` through the
    ``aws-data-wrangler-redshift`` catalog connection, pointing the schema at
    ``glue_database`` with the IAM role found in
    ``parameters["redshift"]["role"]`` and the region found in the stack
    outputs.

    Yields:
        The schema name, ``"aws_data_wrangler_external"``.

    Raises:
        KeyError: if the stack outputs have no ``Region`` entry.
    """
    schema_name = "aws_data_wrangler_external"
    # Index instead of .get(): a missing key would otherwise be rendered as
    # the literal text 'None' inside the statement below.
    region = cloudformation_outputs["Region"]
    iam_role = parameters["redshift"]["role"]
    # Quoted values must not carry padding: 'name ' is not the same
    # identifier as 'name'.
    sql = f"""
    CREATE EXTERNAL SCHEMA IF NOT EXISTS {schema_name} FROM data catalog
    DATABASE '{glue_database}'
    IAM_ROLE '{iam_role}'
    REGION '{region}';
    """
    engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
    with engine.connect() as con:
        con.execute(sql)
    yield schema_name
77+
78+
5979@pytest .mark .parametrize ("db_type" , ["mysql" , "redshift" , "postgresql" ])
6080def test_sql (parameters , db_type ):
6181 df = get_df ()
@@ -305,3 +325,26 @@ def test_redshift_exceptions(bucket, parameters, diststyle, distkey, sortstyle,
305325 index = False ,
306326 )
307327 wr .s3 .delete_objects (path = path )
328+
329+
def test_redshift_spectrum(bucket, glue_database, external_schema):
    """Write a partitioned Parquet dataset and read it back via the external schema.

    The DataFrame is written under ``s3://<bucket>/test_redshift_spectrum/``
    and registered as table ``test_redshift_spectrum`` in ``glue_database``.
    The table is then queried through ``external_schema`` and the result is
    checked for the same number of rows and columns as the DataFrame.
    """
    table = "test_redshift_spectrum"
    df = pd.DataFrame(
        {
            "id": [1, 2, 3, 4, 5],
            "col_str": ["foo", None, "bar", None, "xoo"],
            "par_int": [0, 1, 0, 1, 1],
        }
    )
    path = f"s3://{bucket}/{table}/"
    paths = wr.s3.to_parquet(
        df=df,
        path=path,
        database=glue_database,
        table=table,
        mode="overwrite",
        index=False,
        dataset=True,
        partition_cols=["par_int"],
    )["paths"]
    try:
        # Block until every written object is visible before querying it.
        wr.s3.wait_objects_exist(paths=paths, use_threads=False)
        engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
        with engine.connect() as con:
            cursor = con.execute(f"SELECT * FROM {external_schema}.{table}")
            rows = cursor.fetchall()
        assert len(rows) == len(df.index)
        for row in rows:
            assert len(row) == len(df.columns)
    finally:
        # Remove the written objects even when an assertion fails, as the
        # other S3-backed tests in this module do.
        wr.s3.delete_objects(path=path)
0 commit comments