Skip to content

Commit c9b2cce

Browse files
brandonrisingmaryhipp
authored andcommitted
Add diffusers config object for control loras
1 parent 401fb39 commit c9b2cce

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,18 @@ def get_tag() -> Tag:
292292
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
293293

294294

295+
class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
296+
"""Model config for Control LoRA models."""
297+
298+
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
299+
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
300+
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
301+
302+
@staticmethod
303+
def get_tag() -> Tag:
304+
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.Diffusers.value}")
305+
306+
295307
class LoRADiffusersConfig(LoRAConfigBase):
296308
"""Model config for LoRA/Diffusers models."""
297309

@@ -549,6 +561,7 @@ def get_model_discriminator_value(v: Any) -> str:
549561
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
550562
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
551563
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
564+
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
552565
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
553566
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
554567
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],

invokeai/backend/model_manager/load/model_loaders/lora.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
3939
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
4040
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
41+
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.Diffusers)
4142
class LoRALoader(ModelLoader):
4243
"""Class to load LoRA models."""
4344

0 commit comments

Comments
 (0)