Skip to content

Commit 9b59583

Browse files
Fix unit tests
1 parent a324d4a commit 9b59583

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

bluecast/tests/test_custom_model_recipes.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,27 +132,27 @@ def test_linear_predict(regression_data, linear_model):
132132
X_train, X_test, y_train, y_test = regression_data
133133
linear_model.fit(X_train, X_test, y_train, y_test)
134134

135-
predictions = linear_model.predict(X_test)
135+
result = linear_model.predict(X_test)
136+
137+
# predict returns (preds, preds) tuple to match BaseClassMlModel interface
138+
assert isinstance(result, tuple), "Predictions should be a tuple."
139+
predictions = result[0]
136140

137-
# Check the types of the returned values
138141
assert isinstance(predictions, np.ndarray), "Predictions should be a numpy array."
139142

140-
# Check the shape of the returned values
141143
assert predictions.shape == (
142144
X_test.shape[0],
143145
), "Predictions should have the correct shape."
144146

145-
# Optionally, check the values of predictions are finite (not NaN or inf)
146147
assert np.all(np.isfinite(predictions)), "Predictions should be finite values."
147148

148149

149150
def test_linear_predict_range(regression_data, linear_model):
150151
X_train, X_test, y_train, y_test = regression_data
151152
linear_model.fit(X_train, X_test, y_train, y_test)
152153

153-
predictions = linear_model.predict(X_test)
154+
predictions, _ = linear_model.predict(X_test)
154155

155-
# Check if predictions are within a reasonable range
156156
assert np.all(
157157
predictions >= y_train.min() - 10
158158
), "Predictions should not be too low."

bluecast/tests/test_preprocessing_recipes.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ def test_no_numerical_columns():
102102
preprocessing = PreprocessingForLinearModels(num_columns=[])
103103
transformed_df, transformed_target = preprocessing.fit_transform(df, target)
104104

105-
# Since there are no numerical columns, the DataFrame should remain unchanged
106-
pd.testing.assert_frame_equal(transformed_df, df)
105+
# Since there are no numerical columns, all original columns should be present
106+
assert set(transformed_df.columns) == set(df.columns)
107+
for col in df.columns:
108+
pd.testing.assert_series_equal(
109+
transformed_df[col].reset_index(drop=True),
110+
df[col].reset_index(drop=True),
111+
check_names=False,
112+
)
107113
pd.testing.assert_series_equal(transformed_target, target)

0 commit comments

Comments
 (0)