@@ -106,35 +106,67 @@ def __init__(
106
106
DEFAULTS_PRECISION = Literal ["fp16" , "fp32" ]
107
107
108
108
109
- # Utility from https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
110
- def find_field_schema (model : type [BaseModel ], field_name : str ) -> CoreSchema :
111
- schema : CoreSchema = model .__pydantic_core_schema__ .copy ()
112
- # we shallow copied, be careful not to mutate the original schema!
109
class FieldValidator:
    """Utility class for validating individual fields of a Pydantic model without instantiating the whole model.

    See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
    """

    @staticmethod
    def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
        """Find the Pydantic core schema for a specific field in a model.

        Args:
            model: The model class whose ``__pydantic_core_schema__`` is searched.
            field_name: The name of the field to look up.

        Returns:
            The field's schema. If the model's top-level schema has type ``"definitions"``, a shallow copy of that
            top-level schema is returned instead, with its ``"schema"`` entry replaced by the field's schema.
        """
        # Shallow copy only: nested dicts are still shared with the model, so never mutate them in place.
        schema: CoreSchema = model.__pydantic_core_schema__.copy()

        assert schema["type"] in ["definitions", "model"]

        # Walk down through the nested "schema" wrappers until reaching the one that holds the "fields" mapping.
        field_schema = schema["schema"]  # type: ignore
        while "fields" not in field_schema:
            field_schema = field_schema["schema"]  # type: ignore

        field_schema = field_schema["fields"][field_name]["schema"]  # type: ignore

        # If the original schema is a definitions schema, replace the model schema with the field schema.
        # Rebinding a top-level key on the copy leaves the model's own schema dict untouched.
        if schema["type"] == "definitions":
            schema["schema"] = field_schema
            return schema
        return field_schema

    # `staticmethod` must be the outermost decorator so that `cache` wraps the plain function rather than the
    # staticmethod descriptor (wrapping the descriptor only works where staticmethod objects are callable).
    @staticmethod
    @cache
    def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
        """Get a SchemaValidator for a specific field in a model.

        Results are memoized per ``(model, field_name)`` pair, so the validator is built once per pair.
        """
        return SchemaValidator(FieldValidator.find_field_schema(model, field_name))

    @staticmethod
    def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
        """Validate a value for a specific field in a model.

        Returns whatever the field's cached validator returns from ``validate_python(value)``; any error raised by
        that call propagates to the caller.
        """
        return FieldValidator.get_validator(model, field_name).validate_python(value)
122
147
123
- # if the original schema is a definition schema, replace the model schema with the field schema
124
- if schema ["type" ] == "definitions" :
125
- schema ["schema" ] = field_schema
126
- return schema
127
- else :
128
- return field_schema
129
148
149
+ def has_keys_exact (state_dict : dict [str | int , Any ], keys : str | set [str ]) -> bool :
150
+ """Returns true if the state dict has all of the specified keys."""
151
+ _keys = {keys } if isinstance (keys , str ) else keys
152
+ return _keys .issubset ({key for key in state_dict .keys () if isinstance (key , str )})
130
153
131
- @cache
132
- def validator (model : type [BaseModel ], field_name : str ) -> SchemaValidator :
133
- return SchemaValidator (find_field_schema (model , field_name ))
134
154
155
+ def has_keys_starting_with (state_dict : dict [str | int , Any ], prefixes : str | set [str ]) -> bool :
156
+ """Returns true if the state dict has any keys starting with any of the specified prefixes."""
157
+ _prefixes = {prefixes } if isinstance (prefixes , str ) else prefixes
158
+ return any (any (key .startswith (prefix ) for prefix in _prefixes ) for key in state_dict .keys () if isinstance (key , str ))
135
159
136
- def validate_model_field (model : type [BaseModel ], field_name : str , value : Any ) -> Any :
137
- return validator (model , field_name ).validate_python (value )
160
+
161
+ def has_keys_ending_with (state_dict : dict [str | int , Any ], suffixes : str | set [str ]) -> bool :
162
+ """Returns true if the state dict has any keys ending with any of the specified suffixes."""
163
+ _suffixes = {suffixes } if isinstance (suffixes , str ) else suffixes
164
+ return any (any (key .endswith (suffix ) for suffix in _suffixes ) for key in state_dict .keys () if isinstance (key , str ))
165
+
166
+
167
def common_config_paths(path: Path) -> set[Path]:
    """Build the set of config file locations commonly used by directory-based models.

    The result contains ``config.json`` and ``model_index.json``, each joined onto ``path``. Nothing is checked
    against the filesystem here.
    """
    config_file_names = ("config.json", "model_index.json")
    return {path / file_name for file_name in config_file_names}
138
170
139
171
140
172
# These utility functions are tightly coupled to the config classes below in order to make the process of raising
@@ -225,7 +257,7 @@ def _validate_override_fields(
225
257
if field_name not in config_class .model_fields :
226
258
raise NotAMatch (config_class , f"unknown override field: { field_name } " )
227
259
try :
228
- validate_model_field (config_class , field_name , override_value )
260
+ FieldValidator . validate_field (config_class , field_name , override_value )
229
261
except ValidationError as e :
230
262
raise NotAMatch (config_class , f"invalid override for field '{ field_name } ': { e } " ) from e
231
263
@@ -440,7 +472,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
440
472
441
473
_validate_override_fields (cls , fields )
442
474
443
- _validate_class_name (cls , mod .common_config_paths (), {"T5EncoderModel" })
475
+ _validate_class_name (
476
+ cls ,
477
+ common_config_paths (mod .path ),
478
+ {
479
+ "T5EncoderModel" ,
480
+ },
481
+ )
444
482
445
483
cls ._validate_has_unquantized_config_file (mod )
446
484
@@ -465,7 +503,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
465
503
466
504
_validate_override_fields (cls , fields )
467
505
468
- _validate_class_name (cls , mod .common_config_paths (), {"T5EncoderModel" })
506
+ _validate_class_name (
507
+ cls ,
508
+ common_config_paths (mod .path ),
509
+ {
510
+ "T5EncoderModel" ,
511
+ },
512
+ )
469
513
470
514
cls ._validate_filename_looks_like_bnb_quantized (mod )
471
515
@@ -481,7 +525,7 @@ def _validate_filename_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
481
525
482
526
@classmethod
483
527
def _validate_model_looks_like_bnb_quantized (cls , mod : ModelOnDisk ) -> None :
484
- has_scb_key_suffix = mod .has_keys_ending_with ( "SCB" )
528
+ has_scb_key_suffix = has_keys_ending_with ( mod .load_state_dict (), "SCB" )
485
529
if not has_scb_key_suffix :
486
530
raise NotAMatch (cls , "state dict does not look like bnb quantized llm_int8" )
487
531
@@ -592,23 +636,25 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
592
636
def _validate_looks_like_lora (cls , mod : ModelOnDisk ) -> None :
593
637
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
594
638
# Some main models have these keys, likely due to the creator merging in a LoRA.
595
- has_key_with_lora_prefix = mod .has_keys_starting_with (
639
+ has_key_with_lora_prefix = has_keys_starting_with (
640
+ mod .load_state_dict (),
596
641
{
597
642
"lora_te_" ,
598
643
"lora_unet_" ,
599
644
"lora_te1_" ,
600
645
"lora_te2_" ,
601
646
"lora_transformer_" ,
602
- }
647
+ },
603
648
)
604
649
605
- has_key_with_lora_suffix = mod .has_keys_ending_with (
650
+ has_key_with_lora_suffix = has_keys_ending_with (
651
+ mod .load_state_dict (),
606
652
{
607
653
"to_k_lora.up.weight" ,
608
654
"to_q_lora.down.weight" ,
609
655
"lora_A.weight" ,
610
656
"lora_B.weight" ,
611
- }
657
+ },
612
658
)
613
659
614
660
if not has_key_with_lora_prefix and not has_key_with_lora_suffix :
@@ -754,7 +800,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
754
800
755
801
@classmethod
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
    """Raise NotAMatch unless the model's state dict has a key starting with a VAE conv-in prefix.

    The state dict is obtained from ``mod.load_state_dict()`` and matches when any key starts with
    "encoder.conv_in" or "decoder.conv_in".
    """
    vae_key_prefixes = {
        "encoder.conv_in",
        "decoder.conv_in",
    }
    if has_keys_starting_with(mod.load_state_dict(), vae_key_prefixes):
        return
    raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
759
811
760
812
@classmethod
@@ -786,7 +838,14 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
786
838
787
839
_validate_override_fields (cls , fields )
788
840
789
- _validate_class_name (cls , mod .common_config_paths (), {"AutoencoderKL" , "AutoencoderTiny" })
841
+ _validate_class_name (
842
+ cls ,
843
+ common_config_paths (mod .path ),
844
+ {
845
+ "AutoencoderKL" ,
846
+ "AutoencoderTiny" ,
847
+ },
848
+ )
790
849
791
850
base = fields .get ("base" ) or cls ._get_base_or_raise (mod )
792
851
return cls (** fields , base = base )
@@ -812,7 +871,7 @@ def _guess_name(cls, mod: ModelOnDisk) -> str:
812
871
813
872
@classmethod
814
873
def _get_base_or_raise (cls , mod : ModelOnDisk ) -> VAEDiffusersConfig_SupportedBases :
815
- config = _get_config_or_raise (cls , mod . common_config_paths ())
874
+ config = _get_config_or_raise (cls , common_config_paths (mod . path ))
816
875
if cls ._config_looks_like_sdxl (config ):
817
876
return BaseModelType .StableDiffusionXL
818
877
elif cls ._name_looks_like_sdxl (mod ):
@@ -843,15 +902,22 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
843
902
844
903
_validate_override_fields (cls , fields )
845
904
846
- _validate_class_name (cls , mod .common_config_paths (), {"ControlNetModel" , "FluxControlNetModel" })
905
+ _validate_class_name (
906
+ cls ,
907
+ common_config_paths (mod .path ),
908
+ {
909
+ "ControlNetModel" ,
910
+ "FluxControlNetModel" ,
911
+ },
912
+ )
847
913
848
914
base = fields .get ("base" ) or cls ._get_base_or_raise (mod )
849
915
850
916
return cls (** fields , base = base )
851
917
852
918
@classmethod
853
919
def _get_base_or_raise (cls , mod : ModelOnDisk ) -> ControlNetDiffusers_SupportedBases :
854
- config = _get_config_or_raise (cls , mod . common_config_paths ())
920
+ config = _get_config_or_raise (cls , common_config_paths (mod . path ))
855
921
856
922
if config .get ("_class_name" ) == "FluxControlNetModel" :
857
923
return BaseModelType .Flux
@@ -900,7 +966,8 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
900
966
901
967
@classmethod
902
968
def _validate_looks_like_controlnet (cls , mod : ModelOnDisk ) -> None :
903
- if not mod .has_keys_starting_with (
969
+ if has_keys_starting_with (
970
+ mod .load_state_dict (),
904
971
{
905
972
"controlnet" ,
906
973
"control_model" ,
@@ -911,7 +978,7 @@ def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
911
978
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
912
979
# delicate.
913
980
"controlnet_blocks" ,
914
- }
981
+ },
915
982
):
916
983
raise NotAMatch (cls , "state dict does not look like a ControlNet checkpoint" )
917
984
@@ -1268,7 +1335,8 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1268
1335
1269
1336
@classmethod
1270
1337
def _validate_is_flux (cls , mod : ModelOnDisk ) -> None :
1271
- if not mod .has_keys_exact (
1338
+ if not has_keys_exact (
1339
+ mod .load_state_dict (),
1272
1340
{
1273
1341
"double_blocks.0.img_attn.norm.key_norm.scale" ,
1274
1342
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" ,
@@ -1426,7 +1494,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1426
1494
1427
1495
_validate_class_name (
1428
1496
cls ,
1429
- mod . common_config_paths (),
1497
+ common_config_paths (mod . path ),
1430
1498
{
1431
1499
# SD 1.x and 2.x
1432
1500
"StableDiffusionPipeline" ,
@@ -1527,7 +1595,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1527
1595
1528
1596
_validate_class_name (
1529
1597
cls ,
1530
- mod . common_config_paths (),
1598
+ common_config_paths (mod . path ),
1531
1599
{
1532
1600
"StableDiffusion3Pipeline" ,
1533
1601
"SD3Transformer2DModel" ,
@@ -1548,7 +1616,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1548
1616
@classmethod
1549
1617
def _get_submodels_or_raise (cls , mod : ModelOnDisk ) -> dict [SubModelType , SubmodelDefinition ]:
1550
1618
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
1551
- config = _get_config_or_raise (cls , mod . common_config_paths ())
1619
+ config = _get_config_or_raise (cls , common_config_paths (mod . path ))
1552
1620
1553
1621
submodels : dict [SubModelType , SubmodelDefinition ] = {}
1554
1622
@@ -1601,8 +1669,10 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1601
1669
1602
1670
_validate_class_name (
1603
1671
cls ,
1604
- mod .common_config_paths (),
1605
- {"CogView4Pipeline" },
1672
+ common_config_paths (mod .path ),
1673
+ {
1674
+ "CogView4Pipeline" ,
1675
+ },
1606
1676
)
1607
1677
1608
1678
repo_variant = fields .get ("repo_variant" ) or cls ._get_repo_variant_or_raise (mod )
@@ -1706,13 +1776,14 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1706
1776
1707
1777
@classmethod
1708
1778
def _validate_looks_like_ip_adapter (cls , mod : ModelOnDisk ) -> None :
1709
- if not mod .has_keys_starting_with (
1779
+ if not has_keys_starting_with (
1780
+ mod .load_state_dict (),
1710
1781
{
1711
1782
"image_proj." ,
1712
1783
"ip_adapter." ,
1713
1784
# XLabs FLUX IP-Adapter models have keys starting with "ip_adapter_proj_model.".
1714
1785
"ip_adapter_proj_model." ,
1715
- }
1786
+ },
1716
1787
):
1717
1788
raise NotAMatch (cls , "model does not match Checkpoint IP Adapter heuristics" )
1718
1789
@@ -1778,7 +1849,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1778
1849
1779
1850
_validate_class_name (
1780
1851
cls ,
1781
- mod . common_config_paths (),
1852
+ common_config_paths (mod . path ),
1782
1853
{
1783
1854
"CLIPModel" ,
1784
1855
"CLIPTextModel" ,
@@ -1792,7 +1863,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1792
1863
1793
1864
@classmethod
1794
1865
def _validate_clip_g_variant (cls , mod : ModelOnDisk ) -> None :
1795
- config = _get_config_or_raise (cls , mod . common_config_paths ())
1866
+ config = _get_config_or_raise (cls , common_config_paths (mod . path ))
1796
1867
clip_variant = _get_clip_variant_type_from_config (config )
1797
1868
1798
1869
if clip_variant is not ClipVariantType .G :
@@ -1816,7 +1887,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1816
1887
1817
1888
_validate_class_name (
1818
1889
cls ,
1819
- mod . common_config_paths (),
1890
+ common_config_paths (mod . path ),
1820
1891
{
1821
1892
"CLIPModel" ,
1822
1893
"CLIPTextModel" ,
@@ -1830,7 +1901,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1830
1901
1831
1902
@classmethod
1832
1903
def _validate_clip_l_variant (cls , mod : ModelOnDisk ) -> None :
1833
- config = _get_config_or_raise (cls , mod . common_config_paths ())
1904
+ config = _get_config_or_raise (cls , common_config_paths (mod . path ))
1834
1905
clip_variant = _get_clip_variant_type_from_config (config )
1835
1906
1836
1907
if clip_variant is not ClipVariantType .L :
@@ -1852,7 +1923,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1852
1923
1853
1924
_validate_class_name (
1854
1925
cls ,
1855
- mod . common_config_paths (),
1926
+ common_config_paths (mod . path ),
1856
1927
{
1857
1928
"CLIPVisionModelWithProjection" ,
1858
1929
},
@@ -1882,7 +1953,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1882
1953
1883
1954
_validate_class_name (
1884
1955
cls ,
1885
- mod . common_config_paths (),
1956
+ common_config_paths (mod . path ),
1886
1957
{
1887
1958
"T2IAdapter" ,
1888
1959
},
@@ -1894,7 +1965,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1894
1965
1895
1966
@classmethod
1896
1967
def _get_base_or_raise (cls , mod : ModelOnDisk ) -> T2IAdapterDiffusers_SupportedBases :
1897
- config = _get_config_or_raise (cls , mod . common_config_paths ())
1968
+ config = _get_config_or_raise (cls , common_config_paths (mod . path ))
1898
1969
1899
1970
adapter_type = config .get ("adapter_type" )
1900
1971
@@ -1955,7 +2026,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1955
2026
1956
2027
_validate_class_name (
1957
2028
cls ,
1958
- mod . common_config_paths (),
2029
+ common_config_paths (mod . path ),
1959
2030
{
1960
2031
"SiglipModel" ,
1961
2032
},
@@ -1998,7 +2069,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1998
2069
1999
2070
_validate_class_name (
2000
2071
cls ,
2001
- mod . common_config_paths (),
2072
+ common_config_paths (mod . path ),
2002
2073
{
2003
2074
"LlavaOnevisionForConditionalGeneration" ,
2004
2075
},
0 commit comments