@@ -86,14 +86,22 @@ def redshift_parameters(cloudformation_outputs):
8686def test_to_redshift_pandas (session , bucket , redshift_parameters , sample_name ,
8787 mode , factor , diststyle , distkey , sortstyle ,
8888 sortkey ):
89+ if sample_name == "micro" :
90+ dates = ["date" ]
91+ if sample_name == "small" :
92+ dates = ["date" ]
93+ if sample_name == "nano" :
94+ dates = ["date" , "time" ]
95+ dataframe = pandas .read_csv (f"data_samples/{ sample_name } .csv" ,
96+ parse_dates = dates ,
97+ infer_datetime_format = True )
8998 con = Redshift .generate_connection (
9099 database = "test" ,
91100 host = redshift_parameters .get ("RedshiftAddress" ),
92101 port = redshift_parameters .get ("RedshiftPort" ),
93102 user = "test" ,
94103 password = redshift_parameters .get ("RedshiftPassword" ),
95104 )
96- dataframe = pandas .read_csv (f"data_samples/{ sample_name } .csv" )
97105 path = f"s3://{ bucket } /redshift-load/"
98106 session .pandas .to_redshift (
99107 dataframe = dataframe ,
@@ -110,11 +118,12 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name,
110118 preserve_index = False ,
111119 )
112120 cursor = con .cursor ()
113- cursor .execute ("SELECT COUNT(*) as counter from public.test" )
114- counter = cursor .fetchall ()[ 0 ][ 0 ]
121+ cursor .execute ("SELECT * from public.test" )
122+ rows = cursor .fetchall ()
115123 cursor .close ()
116124 con .close ()
117- assert len (dataframe .index ) * factor == counter
125+ assert len (dataframe .index ) * factor == len (rows )
126+ assert len (list (dataframe .columns )) == len (list (rows [0 ]))
118127
119128
120129@pytest .mark .parametrize (
@@ -135,14 +144,14 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name,
135144def test_to_redshift_pandas_exceptions (session , bucket , redshift_parameters ,
136145 sample_name , mode , factor , diststyle ,
137146 distkey , sortstyle , sortkey , exc ):
147+ dataframe = pandas .read_csv (f"data_samples/{ sample_name } .csv" )
138148 con = Redshift .generate_connection (
139149 database = "test" ,
140150 host = redshift_parameters .get ("RedshiftAddress" ),
141151 port = redshift_parameters .get ("RedshiftPort" ),
142152 user = "test" ,
143153 password = redshift_parameters .get ("RedshiftPassword" ),
144154 )
145- dataframe = pandas .read_csv (f"data_samples/{ sample_name } .csv" )
146155 path = f"s3://{ bucket } /redshift-load/"
147156 with pytest .raises (exc ):
148157 assert session .pandas .to_redshift (
@@ -180,7 +189,20 @@ def test_to_redshift_spark(session, bucket, redshift_parameters, sample_name,
180189 mode , factor , diststyle , distkey , sortstyle ,
181190 sortkey ):
182191 path = f"data_samples/{ sample_name } .csv"
183- dataframe = session .spark .read_csv (path = path )
192+ if sample_name == "micro" :
193+ schema = "id SMALLINT, name STRING, value FLOAT, date TIMESTAMP"
194+ timestamp_format = "yyyy-MM-dd"
195+ elif sample_name == "small" :
196+ schema = "id BIGINT, name STRING, date DATE"
197+ timestamp_format = "dd-MM-yy"
198+ elif sample_name == "nano" :
199+ schema = "id INTEGER, name STRING, value DOUBLE, date TIMESTAMP, time TIMESTAMP"
200+ timestamp_format = "yyyy-MM-dd"
201+ dataframe = session .spark .read_csv (path = path ,
202+ schema = schema ,
203+ timestampFormat = timestamp_format ,
204+ dateFormat = timestamp_format ,
205+ header = True )
184206 con = Redshift .generate_connection (
185207 database = "test" ,
186208 host = redshift_parameters .get ("RedshiftAddress" ),
@@ -203,11 +225,12 @@ def test_to_redshift_spark(session, bucket, redshift_parameters, sample_name,
203225 min_num_partitions = 2 ,
204226 )
205227 cursor = con .cursor ()
206- cursor .execute ("SELECT COUNT(*) as counter from public.test" )
207- counter = cursor .fetchall ()[ 0 ][ 0 ]
228+ cursor .execute ("SELECT * from public.test" )
229+ rows = cursor .fetchall ()
208230 cursor .close ()
209231 con .close ()
210- assert dataframe .count () * factor == counter
232+ assert (dataframe .count () * factor ) == len (rows )
233+ assert len (list (dataframe .columns )) == len (list (rows [0 ]))
211234
212235
213236@pytest .mark .parametrize (
0 commit comments