Skip to content

Commit 568b838

Browse files
authored
Merge pull request #72 from jkoelker/fix/issue-69-model-state-reset
Add model state reset using diffusers maybe_free_model_hooks() API
2 parents d06b3f9 + 05f7718 commit 568b838

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

src/oneiro/pipelines/base.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,18 +180,34 @@ def build_result(
180180
guidance_scale=guidance_scale,
181181
)
182182

183-
def post_generate(self, **kwargs: Any) -> None:
    """Hook invoked after a generation completes.

    The base implementation clears stateful model caches through the
    diffusers `maybe_free_model_hooks()` API, preventing state leakage
    between generations (e.g. KV cache, attention state, hook state).

    Subclasses overriding this hook should invoke
    super().post_generate(**kwargs) before performing any additional
    cleanup of their own (e.g. LoRA restore).

    Note: 'init_image' and 'strength' have already been stripped from
    kwargs by generate(); capture them in pre_generate() if a subclass
    needs those values here.
    """
    self._reset_model_state()
198+
199+
def _reset_model_state(self) -> None:
200+
"""Reset stateful model caches between generations.
201+
202+
Uses the diffusers `maybe_free_model_hooks()` API to reset:
203+
- Stateful caches (KV cache, attention state)
204+
- CPU offload hooks (if model offloading is enabled)
205+
206+
This is the canonical way to reset diffusers pipeline state.
207+
"""
208+
if self.pipe is None:
209+
return
210+
self.pipe.maybe_free_model_hooks()
195211

196212
def unload(self) -> None:
197213
"""Free GPU memory."""

src/oneiro/pipelines/civitai_checkpoint.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,8 @@ def build_generation_kwargs(
854854
return gen_kwargs
855855

856856
def post_generate(self, **kwargs: Any) -> None:
    """Post-generation cleanup: reset model state and restore static LoRAs."""
    # Let the base class reset diffusers model state first.
    super().post_generate(**kwargs)
    # If this generation swapped in dynamic LoRAs, put the static set back.
    if self._has_dynamic_loras:
        self._restore_static_loras()
    self._has_dynamic_loras = False

tests/test_pipelines_base.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,44 @@ def test_configure_handles_none_cpu_count(
246246
assert result >= 1 # Should at least be 1
247247

248248

249+
class TestBasePipelinePostGenerate:
    """Tests for BasePipeline.post_generate() and _reset_model_state()."""

    @patch("oneiro.pipelines.base.torch.cuda.is_available", return_value=False)
    def test_post_generate_calls_reset_model_state(self, mock_cuda):
        """post_generate() calls _reset_model_state()."""
        pipeline = ConcretePipeline()
        reset_spy = Mock()
        pipeline._reset_model_state = reset_spy
        pipeline.post_generate()
        reset_spy.assert_called_once()

    @patch("oneiro.pipelines.base.torch.cuda.is_available", return_value=False)
    def test_reset_model_state_calls_maybe_free_model_hooks(self, mock_cuda):
        """_reset_model_state() calls pipe.maybe_free_model_hooks()."""
        pipeline = ConcretePipeline()
        pipeline.pipe = Mock()
        pipeline._reset_model_state()
        pipeline.pipe.maybe_free_model_hooks.assert_called_once()

    @patch("oneiro.pipelines.base.torch.cuda.is_available", return_value=False)
    def test_reset_model_state_handles_none_pipe(self, mock_cuda):
        """_reset_model_state() handles pipe being None."""
        pipeline = ConcretePipeline()
        pipeline.pipe = None
        pipeline._reset_model_state()  # must not raise

    @patch("oneiro.pipelines.base.torch.cuda.is_available", return_value=False)
    def test_post_generate_accepts_kwargs(self, mock_cuda):
        """post_generate() accepts arbitrary kwargs."""
        pipeline = ConcretePipeline()
        reset_spy = Mock()
        pipeline._reset_model_state = reset_spy
        pipeline.post_generate(some_kwarg="value", another=123)  # must not raise
        reset_spy.assert_called_once()
249287
class TestPipelineManagerLoraResolution:
250288
"""Tests for PipelineManager.generate() LoRA path resolution."""
251289

0 commit comments

Comments
 (0)