Skip to content

Commit c6d8d3f

Browse files
committed
refactor(lora): consolidate LoRA/embedding
Extract is_lora_compatible() and is_embedding_compatible() into single shared is_resource_compatible() function. Both functions were identical, doing the same base model matching against PIPELINE_BASE_MODEL_MAP. Changes: - Rename is_lora_compatible → is_resource_compatible in lora.py - Update embedding.py to import from lora.py instead of duplicating
1 parent 568b838 commit c6d8d3f

File tree

6 files changed

+47
-74
lines changed

6 files changed

+47
-74
lines changed

src/oneiro/discord/commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from oneiro.discord.handlers import DreamContext, create_dream_callbacks
1111
from oneiro.pipelines import SCHEDULER_CHOICES
1212
from oneiro.pipelines.civitai_checkpoint import CivitaiCheckpointPipeline
13-
from oneiro.pipelines.lora import is_lora_compatible
13+
from oneiro.pipelines.lora import is_resource_compatible
1414
from oneiro.queue import QueueStatus
1515
from oneiro.services.generation import (
1616
MAX_GUIDANCE_SCALE,
@@ -633,7 +633,7 @@ async def fetch_command(
633633
# Check compatibility and prepare warning
634634
compatibility_warning = ""
635635
if model_type == "LORA" and pipeline_type and version.base_model:
636-
if not is_lora_compatible(pipeline_type, version.base_model):
636+
if not is_resource_compatible(pipeline_type, version.base_model):
637637
compatibility_warning = (
638638
f"\n⚠️ **Note**: This LoRA (base: {version.base_model}) may not be "
639639
f"compatible with the current model ({pipeline_type})"

src/oneiro/pipelines/embedding.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any
77

88
from oneiro.civitai import CivitaiClient, parse_civitai_url
9-
from oneiro.pipelines.lora import PIPELINE_BASE_MODEL_MAP
9+
from oneiro.pipelines.lora import is_resource_compatible
1010

1111

1212
class EmbeddingSource(str, Enum):
@@ -281,34 +281,6 @@ def parse_embeddings_from_config(
281281
return embeddings
282282

283283

284-
def is_embedding_compatible(pipeline_type: str, civitai_base_model: str | None) -> bool:
285-
"""Check if a Civitai embedding is compatible with a pipeline type.
286-
287-
Args:
288-
pipeline_type: Pipeline type (flux2, zimage, qwen, etc.)
289-
civitai_base_model: Base model string from Civitai API
290-
291-
Returns:
292-
True if compatible, False otherwise
293-
"""
294-
if civitai_base_model is None:
295-
# Can't verify, assume compatible
296-
return True
297-
298-
compatible_bases = PIPELINE_BASE_MODEL_MAP.get(pipeline_type, [])
299-
if not compatible_bases:
300-
# Unknown pipeline type, assume compatible
301-
return True
302-
303-
# Check if any compatible base model matches (case-insensitive substring)
304-
civitai_lower = civitai_base_model.lower()
305-
for base in compatible_bases:
306-
if base.lower() in civitai_lower or civitai_lower in base.lower():
307-
return True
308-
309-
return False
310-
311-
312284
class EmbeddingIncompatibleError(Exception):
313285
"""Raised when an embedding is incompatible with the pipeline type."""
314286

@@ -382,7 +354,7 @@ async def resolve_embedding_path(
382354

383355
# Validate compatibility
384356
if validate_compatibility and pipeline_type:
385-
if not is_embedding_compatible(pipeline_type, version.base_model):
357+
if not is_resource_compatible(pipeline_type, version.base_model):
386358
raise EmbeddingIncompatibleError(
387359
embedding.name,
388360
pipeline_type,

src/oneiro/pipelines/lora.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,8 +468,8 @@ def parse_loras_from_config(
468468
}
469469

470470

471-
def is_lora_compatible(pipeline_type: str, civitai_base_model: str | None) -> bool:
472-
"""Check if a Civitai LoRA is compatible with a pipeline type.
471+
def is_resource_compatible(pipeline_type: str, civitai_base_model: str | None) -> bool:
472+
"""Check if a Civitai resource (LoRA, embedding, etc.) is compatible with a pipeline type.
473473
474474
Args:
475475
pipeline_type: Pipeline type (flux2, zimage, qwen, etc.)
@@ -569,7 +569,7 @@ async def resolve_lora_path(
569569

570570
# Validate compatibility
571571
if validate_compatibility and pipeline_type:
572-
if not is_lora_compatible(pipeline_type, version.base_model):
572+
if not is_resource_compatible(pipeline_type, version.base_model):
573573
raise LoraIncompatibleError(
574574
lora.adapter_name or f"civitai_{lora.civitai_id}",
575575
pipeline_type,

src/oneiro/services/generation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import TYPE_CHECKING, Any
55

66
from oneiro.pipelines import LoraConfig, LoraSource
7-
from oneiro.pipelines.lora import is_lora_compatible
7+
from oneiro.pipelines.lora import is_resource_compatible
88

99
if TYPE_CHECKING:
1010
from oneiro.civitai import CivitaiClient
@@ -184,7 +184,9 @@ async def resolve_loras(
184184
try:
185185
model_info = await civitai_client.get_model(civitai_id)
186186
version = model_info.latest_version
187-
if version and not is_lora_compatible(pipeline_type, version.base_model):
187+
if version and not is_resource_compatible(
188+
pipeline_type, version.base_model
189+
):
188190
result.warnings.append(
189191
f"⚠️ LoRA `{model_info.name}` (base: {version.base_model}) "
190192
f"may not be compatible with current model ({pipeline_type})"
@@ -219,7 +221,7 @@ async def resolve_loras(
219221
try:
220222
model_info = await civitai_client.get_model(civitai_id)
221223
version = model_info.latest_version
222-
if version and not is_lora_compatible(
224+
if version and not is_resource_compatible(
223225
pipeline_type, version.base_model
224226
):
225227
result.warnings.append(

tests/test_embedding.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66
import pytest
77

88
from oneiro.pipelines.embedding import (
9-
PIPELINE_BASE_MODEL_MAP,
109
EmbeddingConfig,
1110
EmbeddingIncompatibleError,
1211
EmbeddingSource,
13-
is_embedding_compatible,
1412
parse_embedding_config,
1513
parse_embeddings_from_config,
1614
)
15+
from oneiro.pipelines.lora import PIPELINE_BASE_MODEL_MAP, is_resource_compatible
1716

1817

1918
class TestEmbeddingConfig:
@@ -341,47 +340,47 @@ def test_no_embeddings(self):
341340

342341

343342
class TestIsEmbeddingCompatible:
344-
"""Tests for is_embedding_compatible function."""
343+
"""Tests for is_resource_compatible function."""
345344

346345
def test_flux1_compatible(self):
347346
"""Flux.1 embeddings compatible with flux1 pipeline."""
348-
assert is_embedding_compatible("flux1", "Flux.1 Dev")
349-
assert is_embedding_compatible("flux1", "Flux.1 Schnell")
350-
assert is_embedding_compatible("flux1", "Flux.1 D")
347+
assert is_resource_compatible("flux1", "Flux.1 Dev")
348+
assert is_resource_compatible("flux1", "Flux.1 Schnell")
349+
assert is_resource_compatible("flux1", "Flux.1 D")
351350

352351
def test_flux2_compatible(self):
353352
"""Flux.2 embeddings compatible with flux2 pipeline."""
354-
assert is_embedding_compatible("flux2", "Flux.2")
353+
assert is_resource_compatible("flux2", "Flux.2")
355354

356355
def test_flux1_flux2_incompatible(self):
357356
"""Flux.1 and Flux.2 embeddings are NOT cross-compatible."""
358-
assert not is_embedding_compatible("flux1", "Flux.2")
359-
assert not is_embedding_compatible("flux2", "Flux.1 Dev")
357+
assert not is_resource_compatible("flux1", "Flux.2")
358+
assert not is_resource_compatible("flux2", "Flux.1 Dev")
360359

361360
def test_sdxl_compatible(self):
362361
"""SDXL embeddings compatible with sdxl pipeline."""
363-
assert is_embedding_compatible("sdxl", "SDXL 1.0")
364-
assert is_embedding_compatible("sdxl", "Pony")
365-
assert is_embedding_compatible("sdxl", "Illustrious")
362+
assert is_resource_compatible("sdxl", "SDXL 1.0")
363+
assert is_resource_compatible("sdxl", "Pony")
364+
assert is_resource_compatible("sdxl", "Illustrious")
366365

367366
def test_incompatible_base_model(self):
368367
"""Incompatible base model returns False."""
369-
assert not is_embedding_compatible("flux2", "SDXL 1.0")
370-
assert not is_embedding_compatible("sdxl", "Flux.1 Dev")
371-
assert not is_embedding_compatible("flux2", "SD 1.5")
368+
assert not is_resource_compatible("flux2", "SDXL 1.0")
369+
assert not is_resource_compatible("sdxl", "Flux.1 Dev")
370+
assert not is_resource_compatible("flux2", "SD 1.5")
372371

373372
def test_none_base_model_is_compatible(self):
374373
"""None base model assumed compatible."""
375-
assert is_embedding_compatible("flux2", None)
374+
assert is_resource_compatible("flux2", None)
376375

377376
def test_unknown_pipeline_is_compatible(self):
378377
"""Unknown pipeline type assumed compatible."""
379-
assert is_embedding_compatible("unknown", "SDXL 1.0")
378+
assert is_resource_compatible("unknown", "SDXL 1.0")
380379

381380
def test_case_insensitive(self):
382381
"""Comparison is case-insensitive."""
383-
assert is_embedding_compatible("flux1", "flux.1 dev")
384-
assert is_embedding_compatible("flux1", "FLUX.1 DEV")
382+
assert is_resource_compatible("flux1", "flux.1 dev")
383+
assert is_resource_compatible("flux1", "FLUX.1 DEV")
385384

386385

387386
class TestEmbeddingIncompatibleError:

tests/test_lora.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
LoraConfig,
1111
LoraIncompatibleError,
1212
LoraSource,
13-
is_lora_compatible,
13+
is_resource_compatible,
1414
parse_lora_config,
1515
parse_loras_from_config,
1616
parse_loras_from_model_config,
@@ -574,40 +574,40 @@ def test_mixed_string_and_dict_refs(self):
574574

575575

576576
class TestIsLoraCompatible:
577-
"""Tests for is_lora_compatible function."""
577+
"""Tests for is_resource_compatible function."""
578578

579579
def test_flux_compatible(self):
580580
"""Flux.1 LoRAs compatible with flux1 pipeline (not flux2)."""
581-
assert is_lora_compatible("flux1", "Flux.1 Dev")
582-
assert is_lora_compatible("flux1", "Flux.1 Schnell")
583-
assert is_lora_compatible("flux1", "Flux.1 D")
584-
assert is_lora_compatible("flux2", "Flux.2")
585-
assert not is_lora_compatible("flux2", "Flux.1 Dev")
581+
assert is_resource_compatible("flux1", "Flux.1 Dev")
582+
assert is_resource_compatible("flux1", "Flux.1 Schnell")
583+
assert is_resource_compatible("flux1", "Flux.1 D")
584+
assert is_resource_compatible("flux2", "Flux.2")
585+
assert not is_resource_compatible("flux2", "Flux.1 Dev")
586586

587587
def test_sdxl_compatible(self):
588588
"""SDXL LoRAs compatible with sdxl pipeline."""
589-
assert is_lora_compatible("sdxl", "SDXL 1.0")
590-
assert is_lora_compatible("sdxl", "Pony")
591-
assert is_lora_compatible("sdxl", "Illustrious")
589+
assert is_resource_compatible("sdxl", "SDXL 1.0")
590+
assert is_resource_compatible("sdxl", "Pony")
591+
assert is_resource_compatible("sdxl", "Illustrious")
592592

593593
def test_incompatible_base_model(self):
594594
"""Incompatible base model returns False."""
595-
assert not is_lora_compatible("flux2", "SDXL 1.0")
596-
assert not is_lora_compatible("sdxl", "Flux.1 Dev")
597-
assert not is_lora_compatible("flux2", "SD 1.5")
595+
assert not is_resource_compatible("flux2", "SDXL 1.0")
596+
assert not is_resource_compatible("sdxl", "Flux.1 Dev")
597+
assert not is_resource_compatible("flux2", "SD 1.5")
598598

599599
def test_none_base_model_is_compatible(self):
600600
"""None base model assumed compatible."""
601-
assert is_lora_compatible("flux2", None)
601+
assert is_resource_compatible("flux2", None)
602602

603603
def test_unknown_pipeline_is_compatible(self):
604604
"""Unknown pipeline type assumed compatible."""
605-
assert is_lora_compatible("unknown", "SDXL 1.0")
605+
assert is_resource_compatible("unknown", "SDXL 1.0")
606606

607607
def test_case_insensitive(self):
608608
"""Comparison is case-insensitive."""
609-
assert is_lora_compatible("flux1", "flux.1 dev")
610-
assert is_lora_compatible("flux1", "FLUX.1 DEV")
609+
assert is_resource_compatible("flux1", "flux.1 dev")
610+
assert is_resource_compatible("flux1", "FLUX.1 DEV")
611611

612612

613613
class TestLoraIncompatibleError:

0 commit comments

Comments
 (0)