@@ -748,44 +748,78 @@ def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None:
748
748
raise NotAMatch (cls , "model state dict does not look like a Flux Control LoRA" )
749
749
750
750
751
- # LoRADiffusers_SupportedBases: TypeAlias = Literal[
752
- # BaseModelType.StableDiffusion1,
753
- # BaseModelType.StableDiffusion2,
754
- # BaseModelType.StableDiffusionXL,
755
- # BaseModelType.Flux,
756
- # ]
751
+ class LoRA_Diffusers_Config_Base (LoRAConfigBase ):
752
+ """Model config for LoRA/Diffusers models."""
757
753
754
+ # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
755
+ # the weights format. FLUX Diffusers LoRAs are single files.
758
756
759
- # class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
760
- # """Model config for LoRA/Diffusers models."""
757
+ format : Literal [ModelFormat .Diffusers ] = Field (default = ModelFormat .Diffusers )
761
758
762
- # # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates
763
- # # the weights format. FLUX Diffusers LoRAs are single files.
759
+ @classmethod
760
+ def from_model_on_disk (cls , mod : ModelOnDisk , fields : dict [str , Any ]) -> Self :
761
+ _validate_is_dir (cls , mod )
764
762
765
- # base: LoRADiffusers_SupportedBases = Field()
766
- # format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
763
+ _validate_override_fields (cls , fields )
767
764
768
- # @classmethod
769
- # def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
770
- # _validate_is_dir(cls, mod)
765
+ cls ._validate_base (mod )
771
766
772
- # _validate_override_fields( cls, fields)
767
+ return cls ( ** fields )
773
768
774
- # cls._validate_looks_like_diffusers_lora(mod)
769
+ @classmethod
770
+ def _validate_base (cls , mod : ModelOnDisk ) -> None :
771
+ """Raise `NotAMatch` if the model base does not match this config class."""
772
+ expected_base = cls .model_fields ["base" ].default .value
773
+ recognized_base = cls ._get_base_or_raise (mod )
774
+ if expected_base is not recognized_base :
775
+ raise NotAMatch (cls , f"base is { recognized_base } , not { expected_base } " )
775
776
776
- # return cls(**fields)
777
+ @classmethod
778
+ def _get_base_or_raise (cls , mod : ModelOnDisk ) -> BaseModelType :
779
+ if _get_flux_lora_format (mod ):
780
+ return BaseModelType .Flux
777
781
778
- # @classmethod
779
- # def _validate_looks_like_diffusers_lora(cls, mod: ModelOnDisk) -> None:
780
- # suffixes = ["bin", "safetensors"]
781
- # weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
782
- # has_lora_weight_file = any(wf.exists() for wf in weight_files)
783
- # if not has_lora_weight_file:
784
- # raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors")
782
+ # If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA
783
+ path_to_weight_file = cls ._get_weight_file_or_raise (mod )
784
+ state_dict = mod .load_state_dict (path_to_weight_file )
785
+ token_vector_length = lora_token_vector_length (state_dict )
785
786
786
- # flux_lora_format = _get_flux_lora_format(mod)
787
- # if flux_lora_format is not FluxLoRAFormat.Diffusers:
788
- # raise NotAMatch(cls, "model does not look like a FLUX Diffusers LoRA")
787
+ match token_vector_length :
788
+ case 768 :
789
+ return BaseModelType .StableDiffusion1
790
+ case 1024 :
791
+ return BaseModelType .StableDiffusion2
792
+ case 1280 :
793
+ return BaseModelType .StableDiffusionXL # recognizes format at https://civitai.com/models/224641
794
+ case 2048 :
795
+ return BaseModelType .StableDiffusionXL
796
+ case _:
797
+ raise NotAMatch (cls , f"unrecognized token vector length { token_vector_length } " )
798
+
799
+ @classmethod
800
+ def _get_weight_file_or_raise (cls , mod : ModelOnDisk ) -> Path :
801
+ suffixes = ["bin" , "safetensors" ]
802
+ weight_files = [mod .path / f"pytorch_lora_weights.{ sfx } " for sfx in suffixes ]
803
+ for wf in weight_files :
804
+ if wf .exists ():
805
+ return wf
806
+ raise NotAMatch (cls , "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors" )
807
+
808
+
809
+ class LoRA_SD1_Diffusers_Config (LoRA_Diffusers_Config_Base , ModelConfigBase ):
810
+ base : Literal [BaseModelType .StableDiffusion1 ] = Field (default = BaseModelType .StableDiffusion1 )
811
+
812
+
813
+ class LoRA_SD2_Diffusers_Config (LoRA_Diffusers_Config_Base , ModelConfigBase ):
814
+ base : Literal [BaseModelType .StableDiffusion2 ] = Field (default = BaseModelType .StableDiffusion2 )
815
+
816
+
817
+ class LoRA_SDXL_Diffusers_Config (LoRA_Diffusers_Config_Base , ModelConfigBase ):
818
+ base : Literal [BaseModelType .StableDiffusionXL ] = Field (default = BaseModelType .StableDiffusionXL )
819
+
820
+
821
+ class LoRA_FLUX_Diffusers_Config (LoRA_Diffusers_Config_Base , ModelConfigBase ):
822
+ base : Literal [BaseModelType .Flux ] = Field (default = BaseModelType .Flux )
789
823
790
824
791
825
class VAE_Checkpoint_Config_Base (CheckpointConfigBase ):
@@ -2332,8 +2366,11 @@ def get_model_discriminator_value(v: Any) -> str:
2332
2366
# LoRA - OMI format
2333
2367
Annotated [LoRA_OMI_SDXL_Config , LoRA_OMI_SDXL_Config .get_tag ()],
2334
2368
Annotated [LoRA_OMI_FLUX_Config , LoRA_OMI_FLUX_Config .get_tag ()],
2335
- # LoRA - diffusers format (TODO)
2336
- # Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
2369
+ # LoRA - diffusers format
2370
+ Annotated [LoRA_SD1_Diffusers_Config , LoRA_SD1_Diffusers_Config .get_tag ()],
2371
+ Annotated [LoRA_SD2_Diffusers_Config , LoRA_SD2_Diffusers_Config .get_tag ()],
2372
+ Annotated [LoRA_SDXL_Diffusers_Config , LoRA_SDXL_Diffusers_Config .get_tag ()],
2373
+ Annotated [LoRA_FLUX_Diffusers_Config , LoRA_FLUX_Diffusers_Config .get_tag ()],
2337
2374
# ControlLoRA - diffusers format
2338
2375
Annotated [ControlLoRA_LyCORIS_FLUX_Config , ControlLoRA_LyCORIS_FLUX_Config .get_tag ()],
2339
2376
Annotated [T5Encoder_T5Encoder_Config , T5Encoder_T5Encoder_Config .get_tag ()],
0 commit comments