|
1 | 1 | """Tests for CivitAI checkpoint pipeline.""" |
2 | 2 |
|
| 3 | +from pathlib import Path |
3 | 4 | from unittest.mock import AsyncMock, MagicMock, patch |
4 | 5 |
|
5 | 6 | import pytest |
|
14 | 15 | get_diffusers_pipeline_class, |
15 | 16 | get_pipeline_config_for_base_model, |
16 | 17 | ) |
| 18 | +from oneiro.pipelines.lora import LoraConfig, LoraSource |
17 | 19 |
|
18 | 20 |
|
19 | 21 | class TestPipelineConfig: |
@@ -1156,3 +1158,166 @@ def test_falls_back_to_sd15_for_unknown_pipeline(self): |
1156 | 1158 | mock_func.assert_called_once() |
1157 | 1159 | assert result["prompt_embeds"] is mock_prompt |
1158 | 1160 | assert result["negative_prompt_embeds"] is mock_neg_prompt |
| 1161 | + |
| 1162 | + |
class TestDynamicLoraGeneration:
    """Tests for dynamic LoRA loading during generation."""

    @staticmethod
    def _make_resolved_lora(name, path):
        """Return a LOCAL LoraConfig with its resolved path pre-populated."""
        cfg = LoraConfig(name=name, source=LoraSource.LOCAL, path=path)
        cfg._resolved_path = Path(path)
        return cfg

    def _create_pipeline_with_mocks(self):
        """Create a pipeline with common mocks for dynamic LoRA tests.

        The mocked diffusers pipe yields one fake 1024x1024 image and CPU
        offload starts disabled.
        """
        pipeline = CivitaiCheckpointPipeline()
        pipeline._pipeline_config = PipelineConfig(
            pipeline_class="StableDiffusionXLPipeline",
            default_steps=25,
            default_guidance_scale=7.0,
            default_width=1024,
            default_height=1024,
        )
        fake_image = MagicMock()
        fake_image.width = 1024
        fake_image.height = 1024
        fake_pipe = MagicMock()
        fake_pipe.return_value.images = [fake_image]
        pipeline.pipe = fake_pipe
        pipeline._cpu_offload = False
        return pipeline

    def test_generate_with_dynamic_loras(self):
        """Dynamic LoRAs passed via kwargs are loaded by generate()."""
        pipeline = self._create_pipeline_with_mocks()
        lora = self._make_resolved_lora("test-lora", "/fake/path.safetensors")

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "_load_dynamic_loras") as load_mock,
            patch.object(pipeline, "_restore_static_loras") as restore_mock,
        ):
            pipeline.generate("test prompt", loras=[lora])

        load_mock.assert_called_once_with([lora])
        restore_mock.assert_called_once()

    def test_generate_restores_static_loras_after_dynamic(self):
        """Static LoRAs are reloaded once the dynamic ones are finished."""
        pipeline = self._create_pipeline_with_mocks()

        static_lora = LoraConfig(
            name="static-lora", source=LoraSource.LOCAL, path="/static.safetensors"
        )
        pipeline._static_lora_configs = [static_lora]
        dynamic_lora = self._make_resolved_lora("dynamic-lora", "/dynamic.safetensors")

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "unload_loras") as unload_mock,
            patch.object(pipeline, "load_single_lora", return_value="dynamic-lora"),
            patch.object(pipeline, "set_lora_adapters"),
            patch.object(pipeline, "load_loras_sync") as load_sync_mock,
        ):
            pipeline.generate("test prompt", loras=[dynamic_lora])

        assert unload_mock.call_count == 2
        load_sync_mock.assert_called_once_with([static_lora])

    def test_generate_handles_dynamic_lora_loading_failure(self):
        """A failure while loading dynamic LoRAs still restores static ones."""
        pipeline = self._create_pipeline_with_mocks()

        static_lora = LoraConfig(
            name="static-lora", source=LoraSource.LOCAL, path="/static.safetensors"
        )
        pipeline._static_lora_configs = [static_lora]
        bad_lora = self._make_resolved_lora("bad-lora", "/bad.safetensors")

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "_load_dynamic_loras", side_effect=RuntimeError("Load failed")),
            patch.object(pipeline, "_restore_static_loras") as restore_mock,
        ):
            with pytest.raises(RuntimeError, match="Load failed"):
                pipeline.generate("test prompt", loras=[bad_lora])

        restore_mock.assert_called_once()

    def test_generate_cleanup_on_generation_failure(self):
        """Dynamic LoRAs are cleaned up even when generation itself raises."""
        pipeline = self._create_pipeline_with_mocks()
        dynamic_lora = self._make_resolved_lora("dynamic-lora", "/dynamic.safetensors")

        pipeline.pipe.side_effect = RuntimeError("Generation failed")

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "_load_dynamic_loras"),
            patch.object(pipeline, "_restore_static_loras") as restore_mock,
        ):
            with pytest.raises(RuntimeError, match="Generation failed"):
                pipeline.generate("test prompt", loras=[dynamic_lora])

        restore_mock.assert_called_once()

    def test_generate_without_dynamic_loras_skips_lora_handling(self):
        """No dynamic LoRAs means neither load nor restore is invoked."""
        pipeline = self._create_pipeline_with_mocks()

        with (
            patch("oneiro.pipelines.civitai_checkpoint.torch"),
            patch.object(pipeline, "_encode_prompts_to_embeddings"),
            patch.object(pipeline, "_load_dynamic_loras") as load_mock,
            patch.object(pipeline, "_restore_static_loras") as restore_mock,
        ):
            pipeline.generate("test prompt")

        load_mock.assert_not_called()
        restore_mock.assert_not_called()

    def test_load_dynamic_loras_respects_cpu_offload(self):
        """With CPU offload on, _load_dynamic_loras() never calls .to(device)."""
        pipeline = self._create_pipeline_with_mocks()
        pipeline._cpu_offload = True
        lora = self._make_resolved_lora("test-lora", "/fake.safetensors")

        with (
            patch.object(pipeline, "unload_loras"),
            patch.object(pipeline, "load_single_lora", return_value="test-lora"),
            patch.object(pipeline, "set_lora_adapters"),
        ):
            pipeline._load_dynamic_loras([lora])

        pipeline.pipe.to.assert_not_called()

    def test_load_dynamic_loras_moves_to_device_without_cpu_offload(self):
        """With CPU offload off, _load_dynamic_loras() moves the pipe to the device."""
        pipeline = self._create_pipeline_with_mocks()
        pipeline._cpu_offload = False
        pipeline._device = "cuda"
        lora = self._make_resolved_lora("test-lora", "/fake.safetensors")

        with (
            patch.object(pipeline, "unload_loras"),
            patch.object(pipeline, "load_single_lora", return_value="test-lora"),
            patch.object(pipeline, "set_lora_adapters"),
        ):
            pipeline._load_dynamic_loras([lora])

        pipeline.pipe.to.assert_called_once_with("cuda")
0 commit comments