@@ -246,6 +246,44 @@ def test_configure_handles_none_cpu_count(
246246 assert result >= 1 # Should at least be 1
247247
248248
249+ class TestBasePipelinePostGenerate :
250+ """Tests for BasePipeline.post_generate() and _reset_model_state()."""
251+
252+ @patch ("oneiro.pipelines.base.torch.cuda.is_available" , return_value = False )
253+ def test_post_generate_calls_reset_model_state (self , mock_cuda ):
254+ """post_generate() calls _reset_model_state()."""
255+ pipeline = ConcretePipeline ()
256+ pipeline ._reset_model_state = Mock ()
257+ pipeline .post_generate ()
258+ pipeline ._reset_model_state .assert_called_once ()
259+
260+ @patch ("oneiro.pipelines.base.torch.cuda.is_available" , return_value = False )
261+ def test_reset_model_state_calls_maybe_free_model_hooks (self , mock_cuda ):
262+ """_reset_model_state() calls pipe.maybe_free_model_hooks()."""
263+ pipeline = ConcretePipeline ()
264+ mock_pipe = Mock ()
265+ pipeline .pipe = mock_pipe
266+ pipeline ._reset_model_state ()
267+ mock_pipe .maybe_free_model_hooks .assert_called_once ()
268+
269+ @patch ("oneiro.pipelines.base.torch.cuda.is_available" , return_value = False )
270+ def test_reset_model_state_handles_none_pipe (self , mock_cuda ):
271+ """_reset_model_state() handles pipe being None."""
272+ pipeline = ConcretePipeline ()
273+ pipeline .pipe = None
274+ # Should not raise
275+ pipeline ._reset_model_state ()
276+
277+ @patch ("oneiro.pipelines.base.torch.cuda.is_available" , return_value = False )
278+ def test_post_generate_accepts_kwargs (self , mock_cuda ):
279+ """post_generate() accepts arbitrary kwargs."""
280+ pipeline = ConcretePipeline ()
281+ pipeline ._reset_model_state = Mock ()
282+ # Should not raise
283+ pipeline .post_generate (some_kwarg = "value" , another = 123 )
284+ pipeline ._reset_model_state .assert_called_once ()
285+
286+
249287class TestPipelineManagerLoraResolution :
250288 """Tests for PipelineManager.generate() LoRA path resolution."""
251289
0 commit comments