Skip to content

Commit 49d4189

Browse files
authored
Merge pull request #58 from jkoelker/jk/lora-fix
fix: resolve device mismatch when loading auto-detected LoRAs
2 parents 2cc2211 + b8c87b4 commit 49d4189

File tree

4 files changed

+391
-5
lines changed

4 files changed

+391
-5
lines changed

src/oneiro/pipelines/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,30 @@ async def generate(
152152
if self.pipeline is None:
153153
raise RuntimeError("No pipeline loaded")
154154

155+
loras: list[LoraConfig] | None = kwargs.pop("loras", None)
156+
if loras:
157+
pipeline_type = None
158+
if self.current_model:
159+
model_config = self.config.get("models", self.current_model)
160+
if model_config:
161+
pipeline_type = model_config.get("type")
162+
163+
resolved_loras: list[LoraConfig] = []
164+
for lora in loras:
165+
try:
166+
await resolve_lora_path(
167+
lora,
168+
civitai_client=self._civitai_client,
169+
pipeline_type=pipeline_type,
170+
validate_compatibility=True,
171+
)
172+
resolved_loras.append(lora)
173+
except Exception as e:
174+
print(f"Warning: Failed to resolve LoRA {lora.name}: {e}")
175+
176+
if resolved_loras:
177+
kwargs["loras"] = resolved_loras
178+
155179
return await asyncio.to_thread(
156180
self.pipeline.generate,
157181
prompt,

src/oneiro/pipelines/civitai_checkpoint.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
get_weighted_text_embeddings_sd15,
2121
get_weighted_text_embeddings_sdxl,
2222
)
23-
from oneiro.pipelines.lora import LoraLoaderMixin, parse_loras_from_model_config
23+
from oneiro.pipelines.lora import LoraConfig, LoraLoaderMixin, parse_loras_from_model_config
2424

2525
if TYPE_CHECKING:
2626
from oneiro.civitai import CivitaiClient, ModelVersion
@@ -518,6 +518,8 @@ def __init__(self) -> None:
518518
self._base_model: str | None = None
519519
self._full_config: dict[str, Any] | None = None
520520
self._current_scheduler: str | None = None
521+
self._static_lora_configs: list[LoraConfig] = []
522+
self._cpu_offload: bool = False
521523

522524
def load(self, model_config: dict[str, Any], full_config: dict[str, Any] | None = None) -> None:
523525
"""Load checkpoint from config (synchronous, requires checkpoint_path).
@@ -631,7 +633,8 @@ def _load_from_path(self, checkpoint_path: Path, model_config: dict[str, Any]) -
631633

632634
# Apply optimizations
633635
cpu_offload = model_config.get("cpu_offload", True)
634-
if cpu_offload and self._device == "cuda":
636+
self._cpu_offload = cpu_offload and self._device == "cuda"
637+
if self._cpu_offload:
635638
self.pipe.enable_model_cpu_offload()
636639
elif self._device == "cuda":
637640
self.pipe.to("cuda")
@@ -643,11 +646,13 @@ def _load_from_path(self, checkpoint_path: Path, model_config: dict[str, Any]) -
643646
if hasattr(self.pipe.vae, "enable_slicing"):
644647
self.pipe.vae.enable_slicing()
645648

646-
# Load LoRAs if configured
647649
loras = parse_loras_from_model_config(model_config)
648650
if loras:
649651
print(f" Loading {len(loras)} LoRA(s)...")
650652
self.load_loras_sync(loras)
653+
self._static_lora_configs = list(loras)
654+
else:
655+
self._static_lora_configs = []
651656

652657
# Load embeddings if full_config provided
653658
if self._full_config:
@@ -743,6 +748,8 @@ def generate(
743748
pass the resulting image to the pipeline.
744749
- ``strength``: Strength parameter for img2img generation, used
745750
when ``init_image`` is provided. Defaults to ``0.75``.
751+
- ``loras``: List of LoraConfig objects for dynamic LoRA loading.
752+
These are loaded before generation and unloaded after.
746753
747754
Any other keyword arguments are passed through unchanged to the
748755
underlying diffusers pipeline call and may be used to control
@@ -765,7 +772,50 @@ def generate(
765772
if scheduler_override:
766773
self.configure_scheduler(scheduler_override)
767774

768-
# Use defaults from pipeline config
775+
dynamic_loras = kwargs.pop("loras", None)
776+
has_dynamic_loras = False
777+
if dynamic_loras:
778+
# Mark that we are entering a dynamic LoRA context before loading,
779+
# so that failures during loading can be properly rolled back.
780+
has_dynamic_loras = True
781+
try:
782+
self._load_dynamic_loras(dynamic_loras)
783+
except Exception:
784+
# If loading dynamic LoRAs fails after modifying the pipeline
785+
# state (for example, after unloading static LoRAs), attempt
786+
# to restore the original static LoRAs before propagating
787+
# the error.
788+
self._restore_static_loras()
789+
has_dynamic_loras = False
790+
raise
791+
792+
try:
793+
return self._run_generation(
794+
prompt=prompt,
795+
negative_prompt=negative_prompt,
796+
width=width,
797+
height=height,
798+
seed=seed,
799+
steps=steps,
800+
guidance_scale=guidance_scale,
801+
**kwargs,
802+
)
803+
finally:
804+
if has_dynamic_loras:
805+
self._restore_static_loras()
806+
807+
def _run_generation(
808+
self,
809+
prompt: str,
810+
negative_prompt: str | None,
811+
width: int | None,
812+
height: int | None,
813+
seed: int,
814+
steps: int | None,
815+
guidance_scale: float | None,
816+
**kwargs: Any,
817+
) -> GenerationResult:
818+
assert self._pipeline_config is not None
769819
width = width or self._pipeline_config.default_width
770820
height = height or self._pipeline_config.default_height
771821
steps = steps or self._pipeline_config.default_steps
@@ -824,6 +874,47 @@ def generate(
824874
guidance_scale=guidance_scale,
825875
)
826876

877+
def _load_dynamic_loras(self, loras: list[LoraConfig]) -> None:
    """Swap in per-request LoRAs, replacing whatever adapters are active.

    Entries that are not ``LoraConfig`` instances are ignored. Any
    statically configured LoRAs are unloaded first; the caller is
    responsible for restoring them afterwards via ``_restore_static_loras``.
    Individual load failures are logged and skipped (best-effort).
    """
    if self.pipe is None or not loras:
        return

    candidates = [item for item in loras if isinstance(item, LoraConfig)]
    if not candidates:
        return

    # Clear the currently active adapters before loading the dynamic set.
    self.unload_loras()

    # When CPU offload is enabled, diffusers owns device placement; only
    # move the pipeline to the target device ourselves when it is not.
    if not self._cpu_offload:
        self.pipe.to(self._device)

    adapter_names: list[str] = []
    adapter_weights: list[float] = []
    for cfg in candidates:
        try:
            adapter = self.load_single_lora(cfg)
            adapter_names.append(adapter)
            adapter_weights.append(cfg.weight)
            print(f"Loaded dynamic LoRA: {cfg.name} (weight={cfg.weight})")
        except Exception as e:
            # Best-effort: a LoRA that fails to load is skipped, not fatal.
            print(f"Warning: Failed to load LoRA {cfg.name}: {e}")

    if adapter_names:
        self.set_lora_adapters(adapter_names, adapter_weights)
def _restore_static_loras(self) -> None:
908+
self.unload_loras()
909+
910+
if not self._static_lora_configs:
911+
return
912+
913+
if not self._cpu_offload:
914+
self.pipe.to(self._device)
915+
self.load_loras_sync(self._static_lora_configs)
916+
print(f"Restored {len(self._static_lora_configs)} static LoRA(s)")
917+
827918
@property
828919
def pipeline_config(self) -> PipelineConfig | None:
829920
"""Get the current pipeline configuration."""

tests/test_civitai_checkpoint.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for CivitAI checkpoint pipeline."""
22

3+
from pathlib import Path
34
from unittest.mock import AsyncMock, MagicMock, patch
45

56
import pytest
@@ -14,6 +15,7 @@
1415
get_diffusers_pipeline_class,
1516
get_pipeline_config_for_base_model,
1617
)
18+
from oneiro.pipelines.lora import LoraConfig, LoraSource
1719

1820

1921
class TestPipelineConfig:
@@ -1156,3 +1158,166 @@ def test_falls_back_to_sd15_for_unknown_pipeline(self):
11561158
mock_func.assert_called_once()
11571159
assert result["prompt_embeds"] is mock_prompt
11581160
assert result["negative_prompt_embeds"] is mock_neg_prompt
1161+
1162+
1163+
class TestDynamicLoraGeneration:
    """Tests covering per-request (dynamic) LoRA handling in generate()."""

    def _create_pipeline_with_mocks(self):
        """Build a CivitaiCheckpointPipeline wired up with mock internals."""
        pipeline = CivitaiCheckpointPipeline()
        pipeline._pipeline_config = PipelineConfig(
            pipeline_class="StableDiffusionXLPipeline",
            default_steps=25,
            default_guidance_scale=7.0,
            default_width=1024,
            default_height=1024,
        )
        fake_image = MagicMock()
        fake_image.width = 1024
        fake_image.height = 1024
        fake_pipe = MagicMock()
        fake_pipe.return_value.images = [fake_image]
        pipeline.pipe = fake_pipe
        pipeline._cpu_offload = False
        return pipeline

    def test_generate_with_dynamic_loras(self):
        """Dynamic LoRAs passed via kwargs are loaded and later restored."""
        pipeline = self._create_pipeline_with_mocks()

        lora = LoraConfig(name="test-lora", source=LoraSource.LOCAL, path="/fake/path.safetensors")
        lora._resolved_path = Path("/fake/path.safetensors")

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "_load_dynamic_loras") as load_mock,
            patch.object(pipeline, "_restore_static_loras") as restore_mock,
        ):
            pipeline.generate("test prompt", loras=[lora])

        load_mock.assert_called_once_with([lora])
        restore_mock.assert_called_once()

    def test_generate_restores_static_loras_after_dynamic(self):
        """After a dynamic-LoRA run, the static LoRAs are reloaded."""
        pipeline = self._create_pipeline_with_mocks()

        static_lora = LoraConfig(
            name="static-lora", source=LoraSource.LOCAL, path="/static.safetensors"
        )
        pipeline._static_lora_configs = [static_lora]

        dynamic_lora = LoraConfig(
            name="dynamic-lora", source=LoraSource.LOCAL, path="/dynamic.safetensors"
        )
        dynamic_lora._resolved_path = Path("/dynamic.safetensors")

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "unload_loras") as unload_mock,
            patch.object(pipeline, "load_single_lora", return_value="dynamic-lora"),
            patch.object(pipeline, "set_lora_adapters"),
            patch.object(pipeline, "load_loras_sync") as load_sync_mock,
        ):
            pipeline.generate("test prompt", loras=[dynamic_lora])

        # Once for the dynamic load, once for the restore.
        assert unload_mock.call_count == 2
        load_sync_mock.assert_called_once_with([static_lora])

    def test_generate_handles_dynamic_lora_loading_failure(self):
        """A failing dynamic load still triggers static-LoRA restoration."""
        pipeline = self._create_pipeline_with_mocks()

        static_lora = LoraConfig(
            name="static-lora", source=LoraSource.LOCAL, path="/static.safetensors"
        )
        pipeline._static_lora_configs = [static_lora]

        dynamic_lora = LoraConfig(name="bad-lora", source=LoraSource.LOCAL, path="/bad.safetensors")
        dynamic_lora._resolved_path = Path("/bad.safetensors")

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "_load_dynamic_loras", side_effect=RuntimeError("Load failed")),
            patch.object(pipeline, "_restore_static_loras") as restore_mock,
        ):
            with pytest.raises(RuntimeError, match="Load failed"):
                pipeline.generate("test prompt", loras=[dynamic_lora])

        restore_mock.assert_called_once()

    def test_generate_cleanup_on_generation_failure(self):
        """Dynamic LoRAs are cleaned up even when generation itself fails."""
        pipeline = self._create_pipeline_with_mocks()

        dynamic_lora = LoraConfig(
            name="dynamic-lora", source=LoraSource.LOCAL, path="/dynamic.safetensors"
        )
        dynamic_lora._resolved_path = Path("/dynamic.safetensors")

        pipeline.pipe.side_effect = RuntimeError("Generation failed")

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "_load_dynamic_loras"),
            patch.object(pipeline, "_restore_static_loras") as restore_mock,
        ):
            with pytest.raises(RuntimeError, match="Generation failed"):
                pipeline.generate("test prompt", loras=[dynamic_lora])

        restore_mock.assert_called_once()

    def test_generate_without_dynamic_loras_skips_lora_handling(self):
        """No LoRA machinery runs when no dynamic LoRAs are supplied."""
        pipeline = self._create_pipeline_with_mocks()

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "_load_dynamic_loras") as load_mock,
            patch.object(pipeline, "_restore_static_loras") as restore_mock,
        ):
            pipeline.generate("test prompt")

        load_mock.assert_not_called()
        restore_mock.assert_not_called()

    def test_load_dynamic_loras_respects_cpu_offload(self):
        """No manual .to(device) call when CPU offload is active."""
        pipeline = self._create_pipeline_with_mocks()
        pipeline._cpu_offload = True

        lora = LoraConfig(name="test-lora", source=LoraSource.LOCAL, path="/fake.safetensors")
        lora._resolved_path = Path("/fake.safetensors")

        with (
            patch.object(pipeline, "unload_loras"),
            patch.object(pipeline, "load_single_lora", return_value="test-lora"),
            patch.object(pipeline, "set_lora_adapters"),
        ):
            pipeline._load_dynamic_loras([lora])

        pipeline.pipe.to.assert_not_called()

    def test_load_dynamic_loras_moves_to_device_without_cpu_offload(self):
        """Pipeline is moved to the device when CPU offload is disabled."""
        pipeline = self._create_pipeline_with_mocks()
        pipeline._cpu_offload = False
        pipeline._device = "cuda"

        lora = LoraConfig(name="test-lora", source=LoraSource.LOCAL, path="/fake.safetensors")
        lora._resolved_path = Path("/fake.safetensors")

        with (
            patch.object(pipeline, "unload_loras"),
            patch.object(pipeline, "load_single_lora", return_value="test-lora"),
            patch.object(pipeline, "set_lora_adapters"),
        ):
            pipeline._load_dynamic_loras([lora])

        pipeline.pipe.to.assert_called_once_with("cuda")

0 commit comments

Comments
 (0)