@@ -132,27 +132,27 @@ def test_linear_predict(regression_data, linear_model):
132132 X_train , X_test , y_train , y_test = regression_data
133133 linear_model .fit (X_train , X_test , y_train , y_test )
134134
135- predictions = linear_model .predict (X_test )
135+ result = linear_model .predict (X_test )
136+
137+ # predict returns (preds, preds) tuple to match BaseClassMlModel interface
138+ assert isinstance (result , tuple ), "Predictions should be a tuple."
139+ predictions = result [0 ]
136140
137- # Check the types of the returned values
138141 assert isinstance (predictions , np .ndarray ), "Predictions should be a numpy array."
139142
140- # Check the shape of the returned values
141143 assert predictions .shape == (
142144 X_test .shape [0 ],
143145 ), "Predictions should have the correct shape."
144146
145- # Optionally, check the values of predictions are finite (not NaN or inf)
146147 assert np .all (np .isfinite (predictions )), "Predictions should be finite values."
147148
148149
149150def test_linear_predict_range (regression_data , linear_model ):
150151 X_train , X_test , y_train , y_test = regression_data
151152 linear_model .fit (X_train , X_test , y_train , y_test )
152153
153- predictions = linear_model .predict (X_test )
154+ predictions , _ = linear_model .predict (X_test )
154155
155- # Check if predictions are within a reasonable range
156156 assert np .all (
157157 predictions >= y_train .min () - 10
158158 ), "Predictions should not be too low."
0 commit comments