Skip to content

Commit d27bef1

Browse files
feat(mm): port cnet to new api
1 parent 07e99c9 commit d27bef1

File tree

2 files changed

+114
-5
lines changed

2 files changed

+114
-5
lines changed

invokeai/backend/flux/controlnet/state_dict_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from invokeai.backend.flux.model import FluxParams
66

77

8-
def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
8+
def is_state_dict_xlabs_controlnet(sd: dict[str | int, Any]) -> bool:
99
"""Is the state dict for an XLabs ControlNet model?
1010
1111
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
@@ -25,7 +25,7 @@ def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
2525
return False
2626

2727

28-
def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
28+
def is_state_dict_instantx_controlnet(sd: dict[str | int, Any]) -> bool:
2929
"""Is the state dict for an InstantX ControlNet model?
3030
3131
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.

invokeai/backend/model_manager/config.py

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@
4444

4545
from invokeai.app.services.config.config_default import get_config
4646
from invokeai.app.util.misc import uuid_string
47+
from invokeai.backend.flux.controlnet.state_dict_utils import (
48+
is_state_dict_instantx_controlnet,
49+
is_state_dict_xlabs_controlnet,
50+
)
4751
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
4852
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
4953
from invokeai.backend.model_hash.hash_validator import validate_hash
@@ -759,13 +763,56 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBas
759763
]
760764

761765

762-
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
766+
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
763767
"""Model config for ControlNet models (diffusers version)."""
764768

765769
base: ControlNetDiffusers_SupportedBases = Field()
766770
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
767771
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
768772

773+
VALID_OVERRIDES: ClassVar = {
774+
"type": ModelType.ControlNet,
775+
"format": ModelFormat.Diffusers,
776+
}
777+
778+
VALID_CLASS_NAMES: ClassVar = {
779+
"ControlNetModel",
780+
"FluxControlNetModel",
781+
}
782+
783+
@classmethod
784+
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
785+
_raise_if_not_dir(cls, mod)
786+
787+
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
788+
789+
_validate_class_names(cls, mod.path / "config.json", cls.VALID_CLASS_NAMES)
790+
791+
base = fields.get("base") or cls._get_base_or_raise(mod)
792+
793+
return cls(**fields, base=base)
794+
795+
@classmethod
796+
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases:
797+
config = _get_config_or_raise(cls, mod.path / "config.json")
798+
799+
if config.get("_class_name") == "FluxControlNetModel":
800+
return BaseModelType.Flux
801+
802+
dimension = config.get("cross_attention_dim")
803+
804+
match dimension:
805+
case 768:
806+
return BaseModelType.StableDiffusion1
807+
case 1024:
808+
# No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them
809+
# anyway.
810+
return BaseModelType.StableDiffusion2
811+
case 2048:
812+
return BaseModelType.StableDiffusionXL
813+
case _:
814+
raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}")
815+
769816

770817
ControlNetCheckpoint_SupportedBases: TypeAlias = Literal[
771818
BaseModelType.StableDiffusion1,
@@ -775,13 +822,75 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, L
775822
]
776823

777824

778-
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
825+
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, ModelConfigBase):
779826
"""Model config for ControlNet models (diffusers version)."""
780827

781828
base: ControlNetDiffusers_SupportedBases = Field()
782829
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
783830
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
784831

832+
VALID_OVERRIDES: ClassVar = {
833+
"type": ModelType.ControlNet,
834+
"format": ModelFormat.Checkpoint,
835+
}
836+
837+
@classmethod
838+
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
839+
_raise_if_not_file(cls, mod)
840+
841+
_validate_overrides(cls, fields, cls.VALID_OVERRIDES)
842+
843+
if not mod.has_keys_starting_with(
844+
{
845+
"controlnet",
846+
"control_model",
847+
"input_blocks",
848+
# XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
849+
# For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
850+
# TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
851+
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
852+
# delicate.
853+
"controlnet_blocks",
854+
}
855+
):
856+
raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint")
857+
858+
base = fields.get("base") or cls._get_base_or_raise(mod)
859+
860+
return cls(**fields, base=base)
861+
862+
@classmethod
863+
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetCheckpoint_SupportedBases:
864+
state_dict = mod.load_state_dict()
865+
866+
if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict):
867+
# TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
868+
# get_format()?
869+
return BaseModelType.Flux
870+
871+
for key in (
872+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
873+
"controlnet_mid_block.bias",
874+
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
875+
"down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
876+
):
877+
if key not in state_dict:
878+
continue
879+
width = state_dict[key].shape[-1]
880+
match width:
881+
case 768:
882+
return BaseModelType.StableDiffusion1
883+
case 1024:
884+
return BaseModelType.StableDiffusion2
885+
case 2048:
886+
return BaseModelType.StableDiffusionXL
887+
case 1280:
888+
return BaseModelType.StableDiffusionXL
889+
case _:
890+
pass
891+
892+
raise NotAMatch(cls, "unable to determine base type from state dict")
893+
785894

786895
TextualInversion_SupportedBases: TypeAlias = Literal[
787896
BaseModelType.StableDiffusion1,
@@ -1247,7 +1356,6 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
12471356
"T2IAdapter",
12481357
}
12491358

1250-
12511359
@classmethod
12521360
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
12531361
_raise_if_not_dir(cls, mod)
@@ -1276,6 +1384,7 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBa
12761384
case _:
12771385
raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'")
12781386

1387+
12791388
class SpandrelImageToImageConfig(ModelConfigBase):
12801389
"""Model config for Spandrel Image to Image models."""
12811390

0 commit comments

Comments
 (0)