Skip to content

Commit 7c6b4f5

Browse files
authored
Merge pull request #70 from jkoelker/fix/issue-69-lora-state-leakage
fix: prevent LoRA state leakage between generation requests
2 parents 03f9362 + 1446c60 commit 7c6b4f5

File tree

7 files changed

+571
-2
lines changed

7 files changed

+571
-2
lines changed

src/oneiro/pipelines/flux2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None
6464
if loras:
6565
print(f" Loading {len(loras)} LoRA(s)...")
6666
self.load_loras_sync(loras)
67+
self.set_static_loras(loras)
6768

6869
# Load embeddings if full_config provided
6970
if full_config:
@@ -131,3 +132,8 @@ def build_generation_kwargs(
131132
"guidance_scale": guidance_scale,
132133
"generator": generator,
133134
}
135+
136+
def post_generate(self, **kwargs: Any) -> None:
137+
"""Reset LoRA state after generation to prevent state leakage."""
138+
super().post_generate(**kwargs)
139+
self.restore_static_loras()

src/oneiro/pipelines/lora.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,8 +605,9 @@ class LoraLoaderMixin:
605605
def __init__(self) -> None:
606606
"""Initialize LoRA state and continue MRO chain."""
607607
super().__init__()
608-
self._lora_configs = []
609-
self._loaded_adapters = []
608+
self._lora_configs: list[LoraConfig] = []
609+
self._loaded_adapters: list[str] = []
610+
self._static_lora_configs: list[LoraConfig] = []
610611

611612
def load_single_lora(
612613
self,
@@ -796,3 +797,42 @@ def active_loras(self) -> list[str]:
796797
def lora_count(self) -> int:
797798
"""Get number of loaded LoRAs."""
798799
return len(self._loaded_adapters)
800+
801+
def set_static_loras(self, loras: list[LoraConfig]) -> None:
802+
"""Store the static LoRA baseline for post-generation restore.
803+
804+
Call this after loading LoRAs from config in load() to establish
805+
which adapters should be preserved between generation requests.
806+
807+
Args:
808+
loras: List of LoRA configurations that represent the static baseline
809+
"""
810+
self._static_lora_configs = list(loras)
811+
812+
def restore_static_loras(self) -> None:
813+
"""Restore LoRA state to static adapters loaded from config.
814+
815+
Call this in post_generate() to prevent state leakage between requests.
816+
If dynamic LoRAs were added during a request, this will unload them
817+
and restore only the static config LoRAs.
818+
819+
Behavior:
820+
- If no static LoRAs and no loaded adapters: no-op
821+
- If adapters match static config: reset weights only
822+
- If adapters differ: unload all and reload static LoRAs
823+
"""
824+
static_names = [lora.adapter_name or lora.name for lora in self._static_lora_configs]
825+
adapters_match = self._loaded_adapters == static_names
826+
827+
if adapters_match:
828+
# Adapters match - just reset weights if there are static LoRAs
829+
if self._static_lora_configs:
830+
adapter_weights = [lora.weight for lora in self._static_lora_configs]
831+
self.set_lora_adapters(static_names, adapter_weights)
832+
return
833+
834+
# Adapters don't match - need to restore to static baseline
835+
self.unload_loras()
836+
if self._static_lora_configs:
837+
self.load_loras_sync(self._static_lora_configs)
838+
print(f"Restored {len(self._static_lora_configs)} static LoRA(s)")

src/oneiro/pipelines/qwen.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None
146146
if loras:
147147
print(f" Loading {len(loras)} LoRA(s)...")
148148
self.load_loras_sync(loras)
149+
self.set_static_loras(loras)
149150

150151
if full_config:
151152
embeddings = parse_embeddings_from_config(full_config, model_config)
@@ -247,3 +248,8 @@ def build_result(
247248
steps=steps,
248249
guidance_scale=guidance_scale,
249250
)
251+
252+
def post_generate(self, **kwargs: Any) -> None:
253+
"""Reset LoRA state after generation to prevent state leakage."""
254+
super().post_generate(**kwargs)
255+
self.restore_static_loras()

src/oneiro/pipelines/zimage.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None
3939
if loras:
4040
print(f" Loading {len(loras)} LoRA(s)...")
4141
self.load_loras_sync(loras)
42+
self.set_static_loras(loras)
4243

4344
# Load embeddings if full_config provided
4445
if full_config:
@@ -107,3 +108,8 @@ def build_result(
107108
steps=steps,
108109
guidance_scale=0.0, # Always 0.0 for Turbo
109110
)
111+
112+
def post_generate(self, **kwargs: Any) -> None:
113+
"""Reset LoRA state after generation to prevent state leakage."""
114+
super().post_generate(**kwargs)
115+
self.restore_static_loras()

tests/test_lora.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,3 +798,153 @@ async def test_civitai_handles_empty_trained_words(self, tmp_path):
798798
)
799799

800800
assert config.trigger_words == []
801+
802+
803+
class TestLoraLoaderMixinStaticLoras:
804+
"""Tests for LoraLoaderMixin static LoRA management methods."""
805+
806+
def test_set_static_loras_stores_configs(self):
807+
"""set_static_loras stores a copy of the configs."""
808+
from oneiro.pipelines.lora import LoraLoaderMixin
809+
810+
class TestMixin(LoraLoaderMixin):
811+
pass
812+
813+
mixin = TestMixin()
814+
loras = [
815+
LoraConfig(name="lora1", source=LoraSource.LOCAL, path="/path/1"),
816+
LoraConfig(name="lora2", source=LoraSource.LOCAL, path="/path/2"),
817+
]
818+
819+
mixin.set_static_loras(loras)
820+
821+
assert mixin._static_lora_configs == loras
822+
# Verify it's a copy, not the same list
823+
assert mixin._static_lora_configs is not loras
824+
825+
def test_set_static_loras_empty_list(self):
826+
"""set_static_loras handles empty list."""
827+
from oneiro.pipelines.lora import LoraLoaderMixin
828+
829+
class TestMixin(LoraLoaderMixin):
830+
pass
831+
832+
mixin = TestMixin()
833+
mixin.set_static_loras([])
834+
835+
assert mixin._static_lora_configs == []
836+
837+
def test_restore_static_loras_no_op_when_empty(self):
838+
"""restore_static_loras is no-op when no static loras and no loaded adapters."""
839+
from oneiro.pipelines.lora import LoraLoaderMixin
840+
841+
class TestMixin(LoraLoaderMixin):
842+
pass
843+
844+
mixin = TestMixin()
845+
mixin.unload_loras = Mock()
846+
mixin.load_loras_sync = Mock()
847+
mixin.set_lora_adapters = Mock()
848+
849+
mixin.restore_static_loras()
850+
851+
mixin.unload_loras.assert_not_called()
852+
mixin.load_loras_sync.assert_not_called()
853+
mixin.set_lora_adapters.assert_not_called()
854+
855+
def test_restore_static_loras_resets_weights_when_adapters_match(self):
856+
"""restore_static_loras only resets weights when adapters match static config."""
857+
from oneiro.pipelines.lora import LoraLoaderMixin
858+
859+
class TestMixin(LoraLoaderMixin):
860+
pass
861+
862+
mixin = TestMixin()
863+
loras = [
864+
LoraConfig(name="lora1", source=LoraSource.LOCAL, path="/path/1", weight=0.8),
865+
LoraConfig(name="lora2", source=LoraSource.LOCAL, path="/path/2", weight=0.5),
866+
]
867+
mixin.set_static_loras(loras)
868+
mixin._loaded_adapters = ["lora1", "lora2"]
869+
870+
mixin.unload_loras = Mock()
871+
mixin.load_loras_sync = Mock()
872+
mixin.set_lora_adapters = Mock()
873+
874+
mixin.restore_static_loras()
875+
876+
mixin.unload_loras.assert_not_called()
877+
mixin.load_loras_sync.assert_not_called()
878+
mixin.set_lora_adapters.assert_called_once_with(["lora1", "lora2"], [0.8, 0.5])
879+
880+
def test_restore_static_loras_reloads_when_adapters_differ(self):
881+
"""restore_static_loras reloads when loaded adapters differ from static."""
882+
from oneiro.pipelines.lora import LoraLoaderMixin
883+
884+
class TestMixin(LoraLoaderMixin):
885+
pass
886+
887+
mixin = TestMixin()
888+
static_loras = [
889+
LoraConfig(name="static-lora", source=LoraSource.LOCAL, path="/path/1"),
890+
]
891+
mixin.set_static_loras(static_loras)
892+
# Simulate dynamic lora was added
893+
mixin._loaded_adapters = ["static-lora", "dynamic-lora"]
894+
895+
mixin.unload_loras = Mock()
896+
mixin.load_loras_sync = Mock()
897+
mixin.set_lora_adapters = Mock()
898+
899+
mixin.restore_static_loras()
900+
901+
mixin.unload_loras.assert_called_once()
902+
mixin.load_loras_sync.assert_called_once_with(static_loras)
903+
mixin.set_lora_adapters.assert_not_called()
904+
905+
def test_restore_static_loras_unloads_only_when_no_static(self):
906+
"""restore_static_loras unloads all when no static but has loaded adapters."""
907+
from oneiro.pipelines.lora import LoraLoaderMixin
908+
909+
class TestMixin(LoraLoaderMixin):
910+
pass
911+
912+
mixin = TestMixin()
913+
mixin._loaded_adapters = ["dynamic-lora"]
914+
915+
mixin.unload_loras = Mock()
916+
mixin.load_loras_sync = Mock()
917+
mixin.set_lora_adapters = Mock()
918+
919+
mixin.restore_static_loras()
920+
921+
mixin.unload_loras.assert_called_once()
922+
mixin.load_loras_sync.assert_not_called()
923+
924+
def test_restore_static_loras_uses_adapter_name_when_set(self):
925+
"""restore_static_loras uses adapter_name if specified, otherwise name."""
926+
from oneiro.pipelines.lora import LoraLoaderMixin
927+
928+
class TestMixin(LoraLoaderMixin):
929+
pass
930+
931+
mixin = TestMixin()
932+
loras = [
933+
LoraConfig(
934+
name="lora1",
935+
source=LoraSource.LOCAL,
936+
path="/path/1",
937+
adapter_name="custom_adapter",
938+
weight=0.7,
939+
),
940+
]
941+
mixin.set_static_loras(loras)
942+
mixin._loaded_adapters = ["custom_adapter"]
943+
944+
mixin.unload_loras = Mock()
945+
mixin.load_loras_sync = Mock()
946+
mixin.set_lora_adapters = Mock()
947+
948+
mixin.restore_static_loras()
949+
950+
mixin.set_lora_adapters.assert_called_once_with(["custom_adapter"], [0.7])

0 commit comments

Comments (0)