Skip to content

Commit 17c5ad2

Browse files
refactor(mm): diffusers loras
w
1 parent 935fafe commit 17c5ad2

File tree

9 files changed

+100
-44
lines changed

9 files changed

+100
-44
lines changed

invokeai/app/api/routers/model_manager.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929
)
3030
from invokeai.app.util.suppress_output import SuppressOutput
3131
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
32-
from invokeai.backend.model_manager.config import AnyModelConfig, SD_1_2_XL_XLRefiner_CheckpointConfig
32+
from invokeai.backend.model_manager.config import (
33+
AnyModelConfig,
34+
Main_SD1_Checkpoint_Config,
35+
Main_SD2_Checkpoint_Config,
36+
Main_SDXL_Checkpoint_Config,
37+
Main_SDXLRefiner_Checkpoint_Config,
38+
)
3339
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
3440
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
3541
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
@@ -738,7 +744,15 @@ async def convert_model(
738744
logger.error(str(e))
739745
raise HTTPException(status_code=424, detail=str(e))
740746

741-
if isinstance(model_config, SD_1_2_XL_XLRefiner_CheckpointConfig):
747+
if isinstance(
748+
model_config,
749+
(
750+
Main_SD1_Checkpoint_Config,
751+
Main_SD2_Checkpoint_Config,
752+
Main_SDXL_Checkpoint_Config,
753+
Main_SDXLRefiner_Checkpoint_Config,
754+
),
755+
):
742756
msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
743757
logger.error(msg)
744758
raise HTTPException(400, msg)

invokeai/app/invocations/flux_ip_adapter.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
1818
from invokeai.app.services.shared.invocation_context import InvocationContext
1919
from invokeai.backend.model_manager.config import (
20-
IPAdapter_InvokeAI_Config_Base,
21-
IPAdapterCheckpointConfig,
20+
IPAdapter_FLUX_Checkpoint_Config,
2221
)
2322
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
2423

@@ -68,7 +67,7 @@ def validate_begin_end_step_percent(self) -> Self:
6867
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
6968
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
7069
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
71-
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
70+
assert isinstance(ip_adapter_info, IPAdapter_FLUX_Checkpoint_Config)
7271

7372
# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
7473
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

invokeai/app/invocations/ip_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from invokeai.app.services.shared.invocation_context import InvocationContext
1414
from invokeai.backend.model_manager.config import (
1515
AnyModelConfig,
16+
IPAdapter_Checkpoint_Config_Base,
1617
IPAdapter_InvokeAI_Config_Base,
17-
IPAdapterCheckpointConfig,
1818
)
1919
from invokeai.backend.model_manager.starter_models import (
2020
StarterModel,
@@ -123,7 +123,7 @@ def validate_begin_end_step_percent(self) -> Self:
123123
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
124124
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
125125
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
126-
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
126+
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base))
127127

128128
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
129129
image_encoder_model_id = ip_adapter_info.image_encoder_model_id

invokeai/backend/model_manager/config.py

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -748,44 +748,78 @@ def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None:
748748
raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA")
749749

750750

751-
# LoRADiffusers_SupportedBases: TypeAlias = Literal[
752-
# BaseModelType.StableDiffusion1,
753-
# BaseModelType.StableDiffusion2,
754-
# BaseModelType.StableDiffusionXL,
755-
# BaseModelType.Flux,
756-
# ]
751+
class LoRA_Diffusers_Config_Base(LoRAConfigBase):
752+
"""Model config for LoRA/Diffusers models."""
757753

754+
# TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
755+
# the weights format. FLUX Diffusers LoRAs are single files.
758756

759-
# class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
760-
# """Model config for LoRA/Diffusers models."""
757+
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
761758

762-
# # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
763-
# # the weights format. FLUX Diffusers LoRAs are single files.
759+
@classmethod
760+
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
761+
_validate_is_dir(cls, mod)
764762

765-
# base: LoRADiffusers_SupportedBases = Field()
766-
# format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
763+
_validate_override_fields(cls, fields)
767764

768-
# @classmethod
769-
# def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
770-
# _validate_is_dir(cls, mod)
765+
cls._validate_base(mod)
771766

772-
# _validate_override_fields(cls, fields)
767+
return cls(**fields)
773768

774-
# cls._validate_looks_like_diffusers_lora(mod)
769+
@classmethod
770+
def _validate_base(cls, mod: ModelOnDisk) -> None:
771+
"""Raise `NotAMatch` if the model base does not match this config class."""
772+
expected_base = cls.model_fields["base"].default.value
773+
recognized_base = cls._get_base_or_raise(mod)
774+
if expected_base is not recognized_base:
775+
raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}")
775776

776-
# return cls(**fields)
777+
@classmethod
778+
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
779+
if _get_flux_lora_format(mod):
780+
return BaseModelType.Flux
777781

778-
# @classmethod
779-
# def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
780-
# suffixes = ["bin", "safetensors"]
781-
# weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
782-
# has_lora_weight_file = any(wf.exists() for wf in weight_files)
783-
# if not has_lora_weight_file:
784-
# raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
782+
# If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA
783+
path_to_weight_file = cls._get_weight_file_or_raise(mod)
784+
state_dict = mod.load_state_dict(path_to_weight_file)
785+
token_vector_length = lora_token_vector_length(state_dict)
785786

786-
# flux_lora_format = _get_flux_lora_format(mod)
787-
# if flux_lora_format is not FluxLoRAFormat.Diffusers:
788-
# raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")
787+
match token_vector_length:
788+
case 768:
789+
return BaseModelType.StableDiffusion1
790+
case 1024:
791+
return BaseModelType.StableDiffusion2
792+
case 1280:
793+
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
794+
case 2048:
795+
return BaseModelType.StableDiffusionXL
796+
case _:
797+
raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}")
798+
799+
@classmethod
800+
def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path:
801+
suffixes = ["bin", "safetensors"]
802+
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
803+
for wf in weight_files:
804+
if wf.exists():
805+
return wf
806+
raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
807+
808+
809+
class LoRA_SD1_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
810+
base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
811+
812+
813+
class LoRA_SD2_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
814+
base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2)
815+
816+
817+
class LoRA_SDXL_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
818+
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
819+
820+
821+
class LoRA_FLUX_Diffusers_Config(LoRA_Diffusers_Config_Base, ModelConfigBase):
822+
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
789823

790824

791825
class VAE_Checkpoint_Config_Base(CheckpointConfigBase):
@@ -2332,8 +2366,11 @@ def get_model_discriminator_value(v: Any) -> str:
23322366
# LoRA - OMI format
23332367
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
23342368
Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
2335-
# LoRA - diffusers format (TODO)
2336-
# Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
2369+
# LoRA - diffusers format
2370+
Annotated[LoRA_SD1_Diffusers_Config, LoRA_SD1_Diffusers_Config.get_tag()],
2371+
Annotated[LoRA_SD2_Diffusers_Config, LoRA_SD2_Diffusers_Config.get_tag()],
2372+
Annotated[LoRA_SDXL_Diffusers_Config, LoRA_SDXL_Diffusers_Config.get_tag()],
2373+
Annotated[LoRA_FLUX_Diffusers_Config, LoRA_FLUX_Diffusers_Config.get_tag()],
23372374
# ControlLoRA - diffusers format
23382375
Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
23392376
Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],

invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
def is_state_dict_likely_in_flux_aitoolkit_format(
16-
state_dict: dict[str, Any],
16+
state_dict: dict[str | int, Any],
1717
metadata: dict[str, Any] | None = None,
1818
) -> bool:
1919
if metadata:
@@ -23,7 +23,7 @@ def is_state_dict_likely_in_flux_aitoolkit_format(
2323
return False
2424
return software.get("name") == "ai-toolkit"
2525
# metadata got lost somewhere
26-
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
26+
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys() if isinstance(k, str))
2727

2828

2929
@dataclass

invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool:
2525
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
2626
"""
2727

28-
all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys())
28+
all_keys_match = all(
29+
re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys() if isinstance(k, str)
30+
)
2931

3032
# Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs.
3133
lora_a_weight = state_dict.get("img_in.lora_A.weight", None)

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
1010

1111

12-
def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
12+
def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, torch.Tensor]) -> bool:
1313
"""Checks if the provided state dict is likely in the Diffusers FLUX LoRA format.
1414
1515
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A
1616
perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.)
1717
"""
1818
# First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. are in PEFT format).
19-
all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys())
19+
all_keys_in_peft_format = all(
20+
k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys() if isinstance(k, str)
21+
)
2022

2123
# Check if keys use transformer prefix
2224
transformer_prefix_keys = [

invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
4545

4646

47-
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
47+
def is_state_dict_likely_in_flux_kohya_format(state_dict: dict[str | int, Any]) -> bool:
4848
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
4949
5050
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
@@ -56,6 +56,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo
5656
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
5757
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
5858
for k in state_dict.keys()
59+
if isinstance(k, str)
5960
)
6061

6162

invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
)
4141

4242

43-
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool:
43+
def is_state_dict_likely_in_flux_onetrainer_format(state_dict: dict[str | int, Any]) -> bool:
4444
"""Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format.
4545
4646
This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A
@@ -53,6 +53,7 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -
5353
or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
5454
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
5555
for k in state_dict.keys()
56+
if isinstance(k, str)
5657
)
5758

5859

0 commit comments

Comments
 (0)