@@ -81,20 +81,33 @@ def test_preprocess_train_predict_assertions(sample_dataframes, sample_column_ty
8181 # Tests that the assertions within preprocess_train_predict fire correctly.
8282
8383 train_df , test_df = sample_dataframes
84+ target_col = "num_col_1"
8485
85- # Test mismatching columns
86- test_df_mismatch = test_df .drop (columns = ["num_col_1" ])
87- with pytest .raises (AssertionError , match = "Columns in df_train and df_test do not match" ):
88- preprocess_train_predict (train_df , test_df_mismatch , "cat_col_1" , sample_column_types )
86+ # 1. Target column not in train_points
87+ with pytest .raises (AssertionError , match = "Target column 'non_existent_col' not found in train_points." ):
88+ preprocess_train_predict (train_df , test_df , "non_existent_col" , sample_column_types )
8989
90- # Test target_col not in column_types
91- with pytest .raises (AssertionError , match = "must appear exactly once" ):
92- preprocess_train_predict (train_df , test_df , "missing_col" , sample_column_types )
90+ # 2. Target column not in test_points
91+ test_df_missing_target = test_df .drop (columns = [target_col ])
92+ with pytest .raises (AssertionError , match = f"Target column '{ target_col } ' not found in test_points." ):
93+ preprocess_train_predict (train_df , test_df_missing_target , target_col , sample_column_types )
9394
94- # Test column_types not matching dataframe columns
95- bad_column_types = {"numerical" : ["num_col_1" ], "categorical" : []}
96- with pytest .raises (AssertionError , match = "must match the columns in the combined dataframe" ):
97- preprocess_train_predict (train_df , test_df , "num_col_1" , bad_column_types )
95+ # 3. Mismatched columns between train and test dataframes
96+ test_df_mismatched = test_df .rename (columns = {"num_col_2" : "new_col_name" })
97+ with pytest .raises (AssertionError , match = "Columns in df_train and df_test do not match" ):
98+ preprocess_train_predict (train_df , test_df_mismatched , target_col , sample_column_types )
99+
100+ # 4. Target column appears more than once in column_types
101+ column_types_duplicate = sample_column_types .copy ()
102+ column_types_duplicate ["categorical" ] = column_types_duplicate ["categorical" ] + [target_col ]
103+ with pytest .raises (AssertionError , match = f"The target column '{ target_col } ' must appear exactly once" ):
104+ preprocess_train_predict (train_df , test_df , target_col , column_types_duplicate )
105+
106+ # 5. Mismatch between dataframe columns and column_types
107+ column_types_mismatch = sample_column_types .copy ()
108+ column_types_mismatch ["numerical" ] = ["num_col_1" ] # Missing num_col_2
109+ with pytest .raises (AssertionError , match = "The union of numeric_columns and categorical_columns must match" ):
110+ preprocess_train_predict (train_df , test_df , target_col , column_types_mismatch )
98111
99112
100113def test_main_feature_extraction (sample_dataframes , sample_column_types ):
0 commit comments