@@ -617,7 +617,7 @@ def test_store_parquet_metadata_modes(database, table, path):
 def test_athena_ctas(path, path2, path3, table, table2, database, kms_key):
     df = get_df_list()
     columns_types, partitions_types = wr.catalog.extract_athena_types(df=df, partition_cols=["par0", "par1"])
-    assert len(columns_types) == 16
+    assert len(columns_types) == 17
     assert len(partitions_types) == 2
     with pytest.raises(wr.exceptions.InvalidArgumentValue):
         wr.catalog.extract_athena_types(df=df, file_format="avro")
@@ -714,7 +714,6 @@ def test_athena(path, database, kms_key, workgroup0, workgroup1):
         keep_files=False,
     )
     for df2 in dfs:
-        print(df2)
         ensure_data_types(df=df2)
     df = wr.athena.read_sql_query(
         sql="SELECT * FROM __test_athena",
@@ -729,12 +728,12 @@ def test_athena(path, database, kms_key, workgroup0, workgroup1):
     wr.catalog.delete_table_if_exists(database=database, table="__test_athena")


-def test_csv(bucket):
+def test_csv(path):
     session = boto3.Session()
     df = pd.DataFrame({"id": [1, 2, 3]})
-    path0 = f"s3://{bucket}/test_csv0.csv"
-    path1 = f"s3://{bucket}/test_csv1.csv"
-    path2 = f"s3://{bucket}/test_csv2.csv"
+    path0 = f"{path}test_csv0.csv"
+    path1 = f"{path}test_csv1.csv"
+    path2 = f"{path}test_csv2.csv"
     wr.s3.to_csv(df=df, path=path0, index=False)
     wr.s3.wait_objects_exist(paths=[path0])
     assert wr.s3.does_object_exist(path=path0) is True
@@ -820,13 +819,12 @@ def test_list_by_last_modified_date(path):
     assert len(wr.s3.read_json(path, last_modified_begin=begin_utc, last_modified_end=end_utc).index) == 6


-def test_parquet(bucket):
-    wr.s3.delete_objects(path=f"s3://{bucket}/test_parquet/")
+def test_parquet(path):
     df_file = pd.DataFrame({"id": [1, 2, 3]})
-    path_file = f"s3://{bucket}/test_parquet/test_parquet_file.parquet"
+    path_file = f"{path}test_parquet_file.parquet"
     df_dataset = pd.DataFrame({"id": [1, 2, 3], "partition": ["A", "A", "B"]})
     df_dataset["partition"] = df_dataset["partition"].astype("category")
-    path_dataset = f"s3://{bucket}/test_parquet/test_parquet_dataset"
+    path_dataset = f"{path}test_parquet_dataset"
     with pytest.raises(wr.exceptions.InvalidArgumentCombination):
         wr.s3.to_parquet(df=df_file, path=path_file, mode="append")
     with pytest.raises(wr.exceptions.InvalidCompression):
@@ -856,7 +854,6 @@ def test_parquet(bucket):
     wr.s3.to_parquet(
         df=df_dataset, path=path_dataset, dataset=True, partition_cols=["partition"], mode="overwrite_partitions"
     )
-    wr.s3.delete_objects(path=f"s3://{bucket}/test_parquet/")


 def test_parquet_catalog(bucket, database):
@@ -919,12 +916,12 @@ def test_parquet_catalog(bucket, database):
     columns_types, partitions_types = wr.s3.read_parquet_metadata(
         path=f"s3://{bucket}/test_parquet_catalog2", dataset=True
     )
-    assert len(columns_types) == 17
+    assert len(columns_types) == 18
     assert len(partitions_types) == 2
     columns_types, partitions_types, partitions_values = wr.s3.store_parquet_metadata(
         path=f"s3://{bucket}/test_parquet_catalog2", database=database, table="test_parquet_catalog2", dataset=True
     )
-    assert len(columns_types) == 17
+    assert len(columns_types) == 18
     assert len(partitions_types) == 2
     assert len(partitions_values) == 2
     wr.s3.delete_objects(path=f"s3://{bucket}/test_parquet_catalog/")
@@ -933,23 +930,11 @@ def test_parquet_catalog(bucket, database):
     assert wr.catalog.delete_table_if_exists(database=database, table="test_parquet_catalog2") is True


-def test_parquet_catalog_duplicated(bucket, database):
-    path = f"s3://{bucket}/test_parquet_catalog_dedup/"
+def test_parquet_catalog_duplicated(path, table, database):
     df = pd.DataFrame({"A": [1], "a": [1]})
-    wr.s3.to_parquet(
-        df=df,
-        path=path,
-        index=False,
-        dataset=True,
-        mode="overwrite",
-        database=database,
-        table="test_parquet_catalog_dedup",
-    )
+    wr.s3.to_parquet(df=df, path=path, index=False, dataset=True, mode="overwrite", database=database, table=table)
     df = wr.s3.read_parquet(path=path)
-    assert len(df.index) == 1
-    assert len(df.columns) == 1
-    wr.s3.delete_objects(path=path)
-    assert wr.catalog.delete_table_if_exists(database=database, table="test_parquet_catalog_dedup") is True
+    assert df.shape == (1, 1)


 def test_parquet_catalog_casting(path, database):
@@ -981,16 +966,13 @@ def test_parquet_catalog_casting(path, database):
     )["paths"]
     wr.s3.wait_objects_exist(paths=paths)
     df = wr.s3.read_parquet(path=path)
-    assert len(df.index) == 3
-    assert len(df.columns) == 15
+    assert df.shape == (3, 16)
     ensure_data_types(df=df, has_list=False)
     df = wr.athena.read_sql_table(table="__test_parquet_catalog_casting", database=database, ctas_approach=True)
-    assert len(df.index) == 3
-    assert len(df.columns) == 15
+    assert df.shape == (3, 16)
     ensure_data_types(df=df, has_list=False)
     df = wr.athena.read_sql_table(table="__test_parquet_catalog_casting", database=database, ctas_approach=False)
-    assert len(df.index) == 3
-    assert len(df.columns) == 15
+    assert df.shape == (3, 16)
     ensure_data_types(df=df, has_list=False)
     wr.s3.delete_objects(path=path)
     assert wr.catalog.delete_table_if_exists(database=database, table="__test_parquet_catalog_casting") is True
@@ -1278,8 +1260,7 @@ def test_parquet_validate_schema(path):
         wr.s3.read_parquet(path=path, validate_schema=True)


-def test_csv_dataset(bucket, database):
-    path = f"s3://{bucket}/test_csv_dataset/"
+def test_csv_dataset(path, database):
     with pytest.raises(wr.exceptions.UndetectedType):
         wr.s3.to_csv(pd.DataFrame({"A": [None]}), path, dataset=True, database=database, table="test_csv_dataset")
     df = get_df_csv()
@@ -1317,8 +1298,7 @@ def test_csv_dataset(bucket, database):
     wr.s3.delete_objects(path=paths)


-def test_csv_catalog(bucket, database):
-    path = f"s3://{bucket}/test_csv_catalog/"
+def test_csv_catalog(path, table, database):
     df = get_df_csv()
     paths = wr.s3.to_csv(
         df=df,
@@ -1331,17 +1311,17 @@ def test_csv_catalog(bucket, database):
         dataset=True,
         partition_cols=["par0", "par1"],
         mode="overwrite",
-        table="test_csv_catalog",
+        table=table,
         database=database,
     )["paths"]
     wr.s3.wait_objects_exist(paths=paths)
-    df2 = wr.athena.read_sql_table("test_csv_catalog", database)
+    df2 = wr.athena.read_sql_table(table, database)
     assert len(df2.index) == 3
     assert len(df2.columns) == 11
     assert df2["id"].sum() == 6
     ensure_data_types_csv(df2)
     wr.s3.delete_objects(path=paths)
-    assert wr.catalog.delete_table_if_exists(database=database, table="test_csv_catalog") is True
+    assert wr.catalog.delete_table_if_exists(database=database, table=table) is True


 def test_csv_catalog_columns(bucket, database):
@@ -2060,26 +2040,23 @@ def test_cache_query_ctas_approach_false(path, database, table):

 def test_cache_query_semicolon(path, database, table):
     df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
-    paths = wr.s3.to_parquet(
-        df=df,
-        path=path,
-        dataset=True,
-        mode="overwrite",
-        database=database,
-        table=table,
-    )["paths"]
+    paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=database, table=table)["paths"]
     wr.s3.wait_objects_exist(paths=paths)

     with patch(
         "awswrangler.athena._check_for_cached_results", return_value={"has_valid_cache": False}
     ) as mocked_cache_attempt:
-        df2 = wr.athena.read_sql_query(f"SELECT * FROM {table}", database=database, ctas_approach=True, max_cache_seconds=0)
+        df2 = wr.athena.read_sql_query(
+            f"SELECT * FROM {table}", database=database, ctas_approach=True, max_cache_seconds=0
+        )
         mocked_cache_attempt.assert_called()
         assert df.shape == df2.shape
         assert df.c0.sum() == df2.c0.sum()

     with patch("awswrangler.athena._resolve_query_without_cache") as resolve_no_cache:
-        df3 = wr.athena.read_sql_query(f"SELECT * FROM {table};", database=database, ctas_approach=True, max_cache_seconds=900)
+        df3 = wr.athena.read_sql_query(
+            f"SELECT * FROM {table};", database=database, ctas_approach=True, max_cache_seconds=900
+        )
         resolve_no_cache.assert_not_called()
         assert df.shape == df3.shape
         assert df.c0.sum() == df3.c0.sum()
@@ -2513,39 +2490,44 @@ def test_sanitize_columns(path, sanitize_columns, col):


 def test_parquet_catalog_casting_to_string(path, table, database):
-    paths = wr.s3.to_parquet(
-        df=get_df_cast(),
-        path=path,
-        index=False,
-        dataset=True,
-        mode="overwrite",
-        database=database,
-        table=table,
-        dtype={
-            "iint8": "string",
-            "iint16": "string",
-            "iint32": "string",
-            "iint64": "string",
-            "float": "string",
-            "double": "double",
-            "decimal": "string",
-            "string": "string",
-            "date": "string",
-            "timestamp": "string",
-            "bool": "string",
-            "binary": "string",
-            "category": "string",
-            "par0": "string",
-            "par1": "string",
-        },
-    )["paths"]
-    wr.s3.wait_objects_exist(paths=paths)
-    df = wr.s3.read_parquet(path=path)
-    assert len(df.index) == 3
-    assert len(df.columns) == 15
-    df = wr.athena.read_sql_table(table=table, database=database, ctas_approach=True)
-    assert len(df.index) == 3
-    assert len(df.columns) == 15
-    df = wr.athena.read_sql_table(table=table, database=database, ctas_approach=False)
-    assert len(df.index) == 3
-    assert len(df.columns) == 15
+    for df in [get_df(), get_df_cast()]:
+        paths = wr.s3.to_parquet(
+            df=df,
+            path=path,
+            index=False,
+            dataset=True,
+            mode="overwrite",
+            database=database,
+            table=table,
+            dtype={
+                "iint8": "string",
+                "iint16": "string",
+                "iint32": "string",
+                "iint64": "string",
+                "float": "string",
+                "double": "string",
+                "decimal": "string",
+                "string": "string",
+                "date": "string",
+                "timestamp": "string",
+                "timestamp2": "string",
+                "bool": "string",
+                "binary": "string",
+                "category": "string",
+                "par0": "string",
+                "par1": "string",
+            },
+        )["paths"]
+        wr.s3.wait_objects_exist(paths=paths)
+        df = wr.s3.read_parquet(path=path)
+        assert df.shape == (3, 16)
+        for dtype in df.dtypes.values:
+            assert str(dtype) == "string"
+        df = wr.athena.read_sql_table(table=table, database=database, ctas_approach=True)
+        assert df.shape == (3, 16)
+        for dtype in df.dtypes.values:
+            assert str(dtype) == "string"
+        df = wr.athena.read_sql_table(table=table, database=database, ctas_approach=False)
+        assert df.shape == (3, 16)
+        for dtype in df.dtypes.values:
+            assert str(dtype) == "string"
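
The recurring change in this diff is the move from the `bucket` fixture (hand-built `s3://{bucket}/...` prefixes plus explicit `wr.s3.delete_objects` / `wr.catalog.delete_table_if_exists` cleanup inside each test) to `path` and `table` fixtures. A minimal sketch of fixtures along those lines is shown below; the fixture names, the hard-coded bucket and database, and the cleanup strategy are illustrative assumptions, not the repository's actual conftest.

# Illustrative sketch only: assumed path/table fixtures resembling what the
# refactored tests receive. Values and cleanup details are assumptions.
import uuid

import pytest

import awswrangler as wr


@pytest.fixture(scope="session")
def bucket():
    # Assumption: the real suite resolves the bucket from test infrastructure;
    # hard-coded here purely for illustration.
    return "my-test-bucket"


@pytest.fixture(scope="session")
def database():
    # Assumption: a pre-created Glue database used by the tests.
    return "awswrangler_test"


@pytest.fixture
def path(bucket):
    # Unique S3 prefix per test; deleting it on teardown replaces the
    # per-test wr.s3.delete_objects(...) calls removed in this diff.
    s3_path = f"s3://{bucket}/{uuid.uuid4()}/"
    yield s3_path
    wr.s3.delete_objects(path=s3_path)


@pytest.fixture
def table(database):
    # Unique Glue table name per test; dropping it on teardown replaces the
    # per-test wr.catalog.delete_table_if_exists(...) calls.
    name = f"tbl_{uuid.uuid4().hex}"
    yield name
    wr.catalog.delete_table_if_exists(database=database, table=name)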