 import json
 import logging
+from datetime import date, datetime
 
 import pytest
 import boto3
-import pandas
+import pandas as pd
 from pyspark.sql import SparkSession
 import pg8000
 
@@ -80,7 +81,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
     dates = ["date"]
     if sample_name == "nano":
         dates = ["date", "time"]
-    dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
+    dataframe = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
     dataframe["date"] = dataframe["date"].dt.date
     con = Redshift.generate_connection(
         database="test",
@@ -113,6 +114,46 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
     assert len(list(dataframe.columns)) + 1 == len(list(rows[0]))
 
 
+def test_to_redshift_pandas_cast(session, bucket, redshift_parameters):
+    df = pd.DataFrame({
+        "id": [1, 2, 3],
+        "name": ["name1", "name2", "name3"],
+        "foo": [None, None, None],
+        "boo": [date(2020, 1, 1), None, None],
+        "bar": [datetime(2021, 1, 1), None, None]})
+    schema = {
+        "id": "BIGINT",
+        "name": "VARCHAR",
+        "foo": "REAL",
+        "boo": "DATE",
+        "bar": "TIMESTAMP"}
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    path = f"s3://{bucket}/redshift-load/"
+    session.pandas.to_redshift(dataframe=df,
+                               path=path,
+                               schema="public",
+                               table="test",
+                               connection=con,
+                               iam_role=redshift_parameters.get("RedshiftRole"),
+                               mode="overwrite",
+                               preserve_index=False,
+                               cast_columns=schema)
+    cursor = con.cursor()
+    cursor.execute("SELECT * from public.test")
+    rows = cursor.fetchall()
+    cursor.close()
+    con.close()
+    print(rows)
+    assert len(df.index) == len(rows)
+    assert len(list(df.columns)) == len(list(rows[0]))
+
+
 @pytest.mark.parametrize(
     "sample_name,mode,factor,diststyle,distkey,exc,sortstyle,sortkey",
     [
@@ -125,7 +166,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
 )
 def test_to_redshift_pandas_exceptions(session, bucket, redshift_parameters, sample_name, mode, factor, diststyle,
                                        distkey, sortstyle, sortkey, exc):
-    dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv")
+    dataframe = pd.read_csv(f"data_samples/{sample_name}.csv")
     con = Redshift.generate_connection(
         database="test",
         host=redshift_parameters.get("RedshiftAddress"),