44
44
45
45
from invokeai .app .services .config .config_default import get_config
46
46
from invokeai .app .util .misc import uuid_string
47
+ from invokeai .backend .flux .controlnet .state_dict_utils import (
48
+ is_state_dict_instantx_controlnet ,
49
+ is_state_dict_xlabs_controlnet ,
50
+ )
47
51
from invokeai .backend .flux .ip_adapter .state_dict_utils import is_state_dict_xlabs_ip_adapter
48
52
from invokeai .backend .flux .redux .flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
49
53
from invokeai .backend .model_hash .hash_validator import validate_hash
@@ -759,13 +763,56 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBas
759
763
]
760
764
761
765
762
- class ControlNetDiffusersConfig (DiffusersConfigBase , ControlAdapterConfigBase , LegacyProbeMixin , ModelConfigBase ):
766
+ class ControlNetDiffusersConfig (DiffusersConfigBase , ControlAdapterConfigBase , ModelConfigBase ):
763
767
"""Model config for ControlNet models (diffusers version)."""
764
768
765
769
base : ControlNetDiffusers_SupportedBases = Field ()
766
770
type : Literal [ModelType .ControlNet ] = Field (default = ModelType .ControlNet )
767
771
format : Literal [ModelFormat .Diffusers ] = Field (default = ModelFormat .Diffusers )
768
772
773
+ VALID_OVERRIDES : ClassVar = {
774
+ "type" : ModelType .ControlNet ,
775
+ "format" : ModelFormat .Diffusers ,
776
+ }
777
+
778
+ VALID_CLASS_NAMES : ClassVar = {
779
+ "ControlNetModel" ,
780
+ "FluxControlNetModel" ,
781
+ }
782
+
783
+ @classmethod
784
+ def from_model_on_disk (cls , mod : ModelOnDisk , fields : dict [str , Any ]) -> Self :
785
+ _raise_if_not_dir (cls , mod )
786
+
787
+ _validate_overrides (cls , fields , cls .VALID_OVERRIDES )
788
+
789
+ _validate_class_names (cls , mod .path / "config.json" , cls .VALID_CLASS_NAMES )
790
+
791
+ base = fields .get ("base" ) or cls ._get_base_or_raise (mod )
792
+
793
+ return cls (** fields , base = base )
794
+
795
+ @classmethod
796
+ def _get_base_or_raise (cls , mod : ModelOnDisk ) -> ControlNetDiffusers_SupportedBases :
797
+ config = _get_config_or_raise (cls , mod .path / "config.json" )
798
+
799
+ if config .get ("_class_name" ) == "FluxControlNetModel" :
800
+ return BaseModelType .Flux
801
+
802
+ dimension = config .get ("cross_attention_dim" )
803
+
804
+ match dimension :
805
+ case 768 :
806
+ return BaseModelType .StableDiffusion1
807
+ case 1024 :
808
+ # No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them
809
+ # anyway.
810
+ return BaseModelType .StableDiffusion2
811
+ case 2048 :
812
+ return BaseModelType .StableDiffusionXL
813
+ case _:
814
+ raise NotAMatch (cls , f"unrecognized cross_attention_dim { dimension } " )
815
+
769
816
770
817
ControlNetCheckpoint_SupportedBases : TypeAlias = Literal [
771
818
BaseModelType .StableDiffusion1 ,
@@ -775,13 +822,75 @@ class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, L
775
822
]
776
823
777
824
778
- class ControlNetCheckpointConfig (CheckpointConfigBase , ControlAdapterConfigBase , LegacyProbeMixin , ModelConfigBase ):
825
+ class ControlNetCheckpointConfig (CheckpointConfigBase , ControlAdapterConfigBase , ModelConfigBase ):
779
826
"""Model config for ControlNet models (diffusers version)."""
780
827
781
828
base : ControlNetDiffusers_SupportedBases = Field ()
782
829
type : Literal [ModelType .ControlNet ] = Field (default = ModelType .ControlNet )
783
830
format : Literal [ModelFormat .Checkpoint ] = Field (default = ModelFormat .Checkpoint )
784
831
832
+ VALID_OVERRIDES : ClassVar = {
833
+ "type" : ModelType .ControlNet ,
834
+ "format" : ModelFormat .Checkpoint ,
835
+ }
836
+
837
+ @classmethod
838
+ def from_model_on_disk (cls , mod : ModelOnDisk , fields : dict [str , Any ]) -> Self :
839
+ _raise_if_not_file (cls , mod )
840
+
841
+ _validate_overrides (cls , fields , cls .VALID_OVERRIDES )
842
+
843
+ if not mod .has_keys_starting_with (
844
+ {
845
+ "controlnet" ,
846
+ "control_model" ,
847
+ "input_blocks" ,
848
+ # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
849
+ # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
850
+ # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
851
+ # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
852
+ # delicate.
853
+ "controlnet_blocks" ,
854
+ }
855
+ ):
856
+ raise NotAMatch (cls , "state dict does not look like a ControlNet checkpoint" )
857
+
858
+ base = fields .get ("base" ) or cls ._get_base_or_raise (mod )
859
+
860
+ return cls (** fields , base = base )
861
+
862
+ @classmethod
863
+ def _get_base_or_raise (cls , mod : ModelOnDisk ) -> ControlNetCheckpoint_SupportedBases :
864
+ state_dict = mod .load_state_dict ()
865
+
866
+ if is_state_dict_xlabs_controlnet (state_dict ) or is_state_dict_instantx_controlnet (state_dict ):
867
+ # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
868
+ # get_format()?
869
+ return BaseModelType .Flux
870
+
871
+ for key in (
872
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" ,
873
+ "controlnet_mid_block.bias" ,
874
+ "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" ,
875
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight" ,
876
+ ):
877
+ if key not in state_dict :
878
+ continue
879
+ width = state_dict [key ].shape [- 1 ]
880
+ match width :
881
+ case 768 :
882
+ return BaseModelType .StableDiffusion1
883
+ case 1024 :
884
+ return BaseModelType .StableDiffusion2
885
+ case 2048 :
886
+ return BaseModelType .StableDiffusionXL
887
+ case 1280 :
888
+ return BaseModelType .StableDiffusionXL
889
+ case _:
890
+ pass
891
+
892
+ raise NotAMatch (cls , "unable to determine base type from state dict" )
893
+
785
894
786
895
TextualInversion_SupportedBases : TypeAlias = Literal [
787
896
BaseModelType .StableDiffusion1 ,
@@ -1247,7 +1356,6 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
1247
1356
"T2IAdapter" ,
1248
1357
}
1249
1358
1250
-
1251
1359
@classmethod
1252
1360
def from_model_on_disk (cls , mod : ModelOnDisk , fields : dict [str , Any ]) -> Self :
1253
1361
_raise_if_not_dir (cls , mod )
@@ -1276,6 +1384,7 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBa
1276
1384
case _:
1277
1385
raise NotAMatch (cls , f"unrecognized adapter_type '{ adapter_type } '" )
1278
1386
1387
+
1279
1388
class SpandrelImageToImageConfig (ModelConfigBase ):
1280
1389
"""Model config for Spandrel Image to Image models."""
1281
1390
0 commit comments