@@ -181,66 +181,6 @@ def test_streaming_history_on_concept_drift(
181181 break
182182
183183
184- def _fit_model (model , X , y ): # noqa: N803
185- model .fit (X = X , y = y )
186- return model
187-
188-
189- @pytest .mark .parametrize (
190- "detector_class" ,
191- [
192- DDM ,
193- ECDDWT ,
194- EDDM ,
195- HDDMA ,
196- HDDMW ,
197- RDDM ,
198- ], # pylint: disable=too-many-locals
199- )
200- def test_streaming_warning_samples_buffer_on_concept_drift (
201- dataset_simple : Tuple [Tuple [np .ndarray , np .ndarray ], Tuple [np .ndarray , np .ndarray ]],
202- model : sklearn .pipeline .Pipeline ,
203- detector_class : BaseSPC ,
204- ):
205- """Test streaming warning samples buffer on concept drift callback.
206-
207- :param dataset_simple: dataset with concept drift
208- :type dataset_simple: Tuple[Tuple[numpy.ndarray, numpy.ndarray],
209- :param model: trained model
210- :type model: sklearn.pipeline.Pipeline
211- :param detector_class: concept drift detector
212- :type detector_class: BaseSPC
213- """
214- _ , test = dataset_simple # noqa: N806
215-
216- detector = detector_class (
217- callbacks = WarningSamplesBuffer (name = "samples" ), # type: ignore
218- )
219-
220- collect_example_warning_samples = False
221- X_extra , y_extra = [], [] # noqa: N806
222-
223- for X , y in zip (* test ): # noqa: N806
224- y_pred = model .predict (X .reshape (1 , - 1 ))
225- if not collect_example_warning_samples :
226- error = 1 - int (y_pred == y )
227- callbacks_logs = detector .update (value = error , X = X , y = y )
228- else :
229- X_extra .append (X )
230- y_extra .append (y )
231- if detector .status ["drift" ]:
232- y_new_ref = callbacks_logs ["samples" ]["y" ] + y_extra
233- if len ([* set (y_new_ref )]) < 2 :
234- collect_example_warning_samples = True
235- else :
236- X_new_ref = callbacks_logs ["samples" ]["X" ] + X_extra # noqa: N806
237- collect_example_warning_samples = False
238- X_extra .clear ()
239- y_extra .clear ()
240- detector .reset ()
241- model = _fit_model (model = model , X = X_new_ref , y = y_new_ref )
242-
243-
244184@pytest .mark .parametrize (
245185 "detector_class,"
246186 " expected_drift_idx,"
0 commit comments