Skip to content

Commit 09449cf

Browse files
tidy(mm): clean up model heuristic utils
1 parent 9676cb8 commit 09449cf

File tree

2 files changed

+124
-76
lines changed

2 files changed

+124
-76
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 124 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -106,35 +106,67 @@ def __init__(
106106
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
107107

108108

109-
# Utility from https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
110-
def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
111-
schema: CoreSchema = model.__pydantic_core_schema__.copy()
112-
# we shallow copied, be careful not to mutate the original schema!
109+
class FieldValidator:
    """Utility class for validating individual fields of a Pydantic model without instantiating the whole model.

    All members are static; the class only serves as a namespace.

    See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144
    """

    @staticmethod
    def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema:
        """Find the Pydantic core schema for a specific field in a model.

        Args:
            model: The Pydantic model class whose core schema is searched.
            field_name: Name of the field to look up in the model's "fields" mapping.

        Returns:
            The field's schema. When the model's top-level schema has type "definitions", a shallow copy of that
            top-level schema is returned with its inner "schema" replaced by the field schema.

        Raises:
            AssertionError: If the top-level schema type is neither "definitions" nor "model".
            KeyError: If ``field_name`` is not present in the "fields" mapping.
        """
        # Shallow copy only: nested dicts are still shared with the model's own schema, so nothing below may mutate
        # them. The only write is to the top-level "schema" key of the copy.
        schema: CoreSchema = model.__pydantic_core_schema__.copy()

        assert schema["type"] in ["definitions", "model"]

        # Descend through nested "schema" entries until we reach the one that holds the "fields" mapping.
        field_schema = schema["schema"]  # type: ignore
        while "fields" not in field_schema:
            field_schema = field_schema["schema"]  # type: ignore

        field_schema = field_schema["fields"][field_name]["schema"]  # type: ignore

        # If the original schema is a definitions schema, keep that wrapper and swap the model schema for the field
        # schema (presumably so definition references used by the field still resolve).
        if schema["type"] == "definitions":
            schema["schema"] = field_schema
            return schema
        return field_schema

    # `staticmethod` must be the outermost decorator. With `cache` on the outside, the class attribute is the bare
    # cache wrapper, which binds `self` when accessed through an instance and relies on staticmethod objects being
    # callable (Python 3.10+).
    @staticmethod
    @cache
    def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
        """Get a SchemaValidator for a specific field in a model.

        Results are memoized per ``(model, field_name)`` pair, so the validator is built only once for each pair.
        """
        return SchemaValidator(FieldValidator.find_field_schema(model, field_name))

    @staticmethod
    def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
        """Validate a value for a specific field in a model.

        Returns whatever the field's validator returns from ``validate_python(value)``; exceptions raised by that
        call are not caught here.
        """
        return FieldValidator.get_validator(model, field_name).validate_python(value)
122147

123-
# if the original schema is a definition schema, replace the model schema with the field schema
124-
if schema["type"] == "definitions":
125-
schema["schema"] = field_schema
126-
return schema
127-
else:
128-
return field_schema
129148

149+
def has_keys_exact(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool:
150+
"""Returns true if the state dict has all of the specified keys."""
151+
_keys = {keys} if isinstance(keys, str) else keys
152+
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
130153

131-
@cache
132-
def validator(model: type[BaseModel], field_name: str) -> SchemaValidator:
133-
return SchemaValidator(find_field_schema(model, field_name))
134154

155+
def has_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool:
156+
"""Returns true if the state dict has any keys starting with any of the specified prefixes."""
157+
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
158+
return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str))
135159

136-
def validate_model_field(model: type[BaseModel], field_name: str, value: Any) -> Any:
137-
return validator(model, field_name).validate_python(value)
160+
161+
def has_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool:
162+
"""Returns true if the state dict has any keys ending with any of the specified suffixes."""
163+
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
164+
return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str))
165+
166+
167+
def common_config_paths(path: Path) -> set[Path]:
    """Returns common config file paths for models stored in directories.

    The paths are built by joining ``path`` with each well-known config filename; nothing is read from disk here,
    so the returned paths may or may not exist.
    """
    config_filenames = ("config.json", "model_index.json")
    return {path / filename for filename in config_filenames}
138170

139171

140172
# These utility functions are tightly coupled to the config classes below in order to make the process of raising
@@ -225,7 +257,7 @@ def _validate_override_fields(
225257
if field_name not in config_class.model_fields:
226258
raise NotAMatch(config_class, f"unknown override field: {field_name}")
227259
try:
228-
validate_model_field(config_class, field_name, override_value)
260+
FieldValidator.validate_field(config_class, field_name, override_value)
229261
except ValidationError as e:
230262
raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e
231263

@@ -440,7 +472,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
440472

441473
_validate_override_fields(cls, fields)
442474

443-
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
475+
_validate_class_name(
476+
cls,
477+
common_config_paths(mod.path),
478+
{
479+
"T5EncoderModel",
480+
},
481+
)
444482

445483
cls._validate_has_unquantized_config_file(mod)
446484

@@ -465,7 +503,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
465503

466504
_validate_override_fields(cls, fields)
467505

468-
_validate_class_name(cls, mod.common_config_paths(), {"T5EncoderModel"})
506+
_validate_class_name(
507+
cls,
508+
common_config_paths(mod.path),
509+
{
510+
"T5EncoderModel",
511+
},
512+
)
469513

470514
cls._validate_filename_looks_like_bnb_quantized(mod)
471515

@@ -481,7 +525,7 @@ def _validate_filename_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
481525

482526
@classmethod
483527
def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
484-
has_scb_key_suffix = mod.has_keys_ending_with("SCB")
528+
has_scb_key_suffix = has_keys_ending_with(mod.load_state_dict(), "SCB")
485529
if not has_scb_key_suffix:
486530
raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8")
487531

@@ -592,23 +636,25 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
592636
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
593637
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
594638
# Some main models have these keys, likely due to the creator merging in a LoRA.
595-
has_key_with_lora_prefix = mod.has_keys_starting_with(
639+
has_key_with_lora_prefix = has_keys_starting_with(
640+
mod.load_state_dict(),
596641
{
597642
"lora_te_",
598643
"lora_unet_",
599644
"lora_te1_",
600645
"lora_te2_",
601646
"lora_transformer_",
602-
}
647+
},
603648
)
604649

605-
has_key_with_lora_suffix = mod.has_keys_ending_with(
650+
has_key_with_lora_suffix = has_keys_ending_with(
651+
mod.load_state_dict(),
606652
{
607653
"to_k_lora.up.weight",
608654
"to_q_lora.down.weight",
609655
"lora_A.weight",
610656
"lora_B.weight",
611-
}
657+
},
612658
)
613659

614660
if not has_key_with_lora_prefix and not has_key_with_lora_suffix:
@@ -754,7 +800,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
754800

755801
@classmethod
756802
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
757-
if not mod.has_keys_starting_with({"encoder.conv_in", "decoder.conv_in"}):
803+
if not has_keys_starting_with(
804+
mod.load_state_dict(),
805+
{
806+
"encoder.conv_in",
807+
"decoder.conv_in",
808+
},
809+
):
758810
raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics")
759811

760812
@classmethod
@@ -786,7 +838,14 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
786838

787839
_validate_override_fields(cls, fields)
788840

789-
_validate_class_name(cls, mod.common_config_paths(), {"AutoencoderKL", "AutoencoderTiny"})
841+
_validate_class_name(
842+
cls,
843+
common_config_paths(mod.path),
844+
{
845+
"AutoencoderKL",
846+
"AutoencoderTiny",
847+
},
848+
)
790849

791850
base = fields.get("base") or cls._get_base_or_raise(mod)
792851
return cls(**fields, base=base)
@@ -812,7 +871,7 @@ def _guess_name(cls, mod: ModelOnDisk) -> str:
812871

813872
@classmethod
814873
def _get_base_or_raise(cls, mod: ModelOnDisk) -> VAEDiffusersConfig_SupportedBases:
815-
config = _get_config_or_raise(cls, mod.common_config_paths())
874+
config = _get_config_or_raise(cls, common_config_paths(mod.path))
816875
if cls._config_looks_like_sdxl(config):
817876
return BaseModelType.StableDiffusionXL
818877
elif cls._name_looks_like_sdxl(mod):
@@ -843,15 +902,22 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
843902

844903
_validate_override_fields(cls, fields)
845904

846-
_validate_class_name(cls, mod.common_config_paths(), {"ControlNetModel", "FluxControlNetModel"})
905+
_validate_class_name(
906+
cls,
907+
common_config_paths(mod.path),
908+
{
909+
"ControlNetModel",
910+
"FluxControlNetModel",
911+
},
912+
)
847913

848914
base = fields.get("base") or cls._get_base_or_raise(mod)
849915

850916
return cls(**fields, base=base)
851917

852918
@classmethod
853919
def _get_base_or_raise(cls, mod: ModelOnDisk) -> ControlNetDiffusers_SupportedBases:
854-
config = _get_config_or_raise(cls, mod.common_config_paths())
920+
config = _get_config_or_raise(cls, common_config_paths(mod.path))
855921

856922
if config.get("_class_name") == "FluxControlNetModel":
857923
return BaseModelType.Flux
@@ -900,7 +966,8 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
900966

901967
@classmethod
902968
def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
903-
if not mod.has_keys_starting_with(
969+
if not has_keys_starting_with(
970+
mod.load_state_dict(),
904971
{
905972
"controlnet",
906973
"control_model",
@@ -911,7 +978,7 @@ def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None:
911978
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
912979
# delicate.
913980
"controlnet_blocks",
914-
}
981+
},
915982
):
916983
raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint")
917984

@@ -1268,7 +1335,8 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
12681335

12691336
@classmethod
12701337
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
1271-
if not mod.has_keys_exact(
1338+
if not has_keys_exact(
1339+
mod.load_state_dict(),
12721340
{
12731341
"double_blocks.0.img_attn.norm.key_norm.scale",
12741342
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -1426,7 +1494,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
14261494

14271495
_validate_class_name(
14281496
cls,
1429-
mod.common_config_paths(),
1497+
common_config_paths(mod.path),
14301498
{
14311499
# SD 1.x and 2.x
14321500
"StableDiffusionPipeline",
@@ -1527,7 +1595,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
15271595

15281596
_validate_class_name(
15291597
cls,
1530-
mod.common_config_paths(),
1598+
common_config_paths(mod.path),
15311599
{
15321600
"StableDiffusion3Pipeline",
15331601
"SD3Transformer2DModel",
@@ -1548,7 +1616,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
15481616
@classmethod
15491617
def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]:
15501618
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
1551-
config = _get_config_or_raise(cls, mod.common_config_paths())
1619+
config = _get_config_or_raise(cls, common_config_paths(mod.path))
15521620

15531621
submodels: dict[SubModelType, SubmodelDefinition] = {}
15541622

@@ -1601,8 +1669,10 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
16011669

16021670
_validate_class_name(
16031671
cls,
1604-
mod.common_config_paths(),
1605-
{"CogView4Pipeline"},
1672+
common_config_paths(mod.path),
1673+
{
1674+
"CogView4Pipeline",
1675+
},
16061676
)
16071677

16081678
repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
@@ -1706,13 +1776,14 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
17061776

17071777
@classmethod
17081778
def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None:
1709-
if not mod.has_keys_starting_with(
1779+
if not has_keys_starting_with(
1780+
mod.load_state_dict(),
17101781
{
17111782
"image_proj.",
17121783
"ip_adapter.",
17131784
# XLabs FLUX IP-Adapter models have keys starting with "ip_adapter_proj_model.".
17141785
"ip_adapter_proj_model.",
1715-
}
1786+
},
17161787
):
17171788
raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics")
17181789

@@ -1778,7 +1849,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
17781849

17791850
_validate_class_name(
17801851
cls,
1781-
mod.common_config_paths(),
1852+
common_config_paths(mod.path),
17821853
{
17831854
"CLIPModel",
17841855
"CLIPTextModel",
@@ -1792,7 +1863,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
17921863

17931864
@classmethod
17941865
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
1795-
config = _get_config_or_raise(cls, mod.common_config_paths())
1866+
config = _get_config_or_raise(cls, common_config_paths(mod.path))
17961867
clip_variant = _get_clip_variant_type_from_config(config)
17971868

17981869
if clip_variant is not ClipVariantType.G:
@@ -1816,7 +1887,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
18161887

18171888
_validate_class_name(
18181889
cls,
1819-
mod.common_config_paths(),
1890+
common_config_paths(mod.path),
18201891
{
18211892
"CLIPModel",
18221893
"CLIPTextModel",
@@ -1830,7 +1901,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
18301901

18311902
@classmethod
18321903
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
1833-
config = _get_config_or_raise(cls, mod.common_config_paths())
1904+
config = _get_config_or_raise(cls, common_config_paths(mod.path))
18341905
clip_variant = _get_clip_variant_type_from_config(config)
18351906

18361907
if clip_variant is not ClipVariantType.L:
@@ -1852,7 +1923,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
18521923

18531924
_validate_class_name(
18541925
cls,
1855-
mod.common_config_paths(),
1926+
common_config_paths(mod.path),
18561927
{
18571928
"CLIPVisionModelWithProjection",
18581929
},
@@ -1882,7 +1953,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
18821953

18831954
_validate_class_name(
18841955
cls,
1885-
mod.common_config_paths(),
1956+
common_config_paths(mod.path),
18861957
{
18871958
"T2IAdapter",
18881959
},
@@ -1894,7 +1965,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
18941965

18951966
@classmethod
18961967
def _get_base_or_raise(cls, mod: ModelOnDisk) -> T2IAdapterDiffusers_SupportedBases:
1897-
config = _get_config_or_raise(cls, mod.common_config_paths())
1968+
config = _get_config_or_raise(cls, common_config_paths(mod.path))
18981969

18991970
adapter_type = config.get("adapter_type")
19001971

@@ -1955,7 +2026,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
19552026

19562027
_validate_class_name(
19572028
cls,
1958-
mod.common_config_paths(),
2029+
common_config_paths(mod.path),
19592030
{
19602031
"SiglipModel",
19612032
},
@@ -1998,7 +2069,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
19982069

19992070
_validate_class_name(
20002071
cls,
2001-
mod.common_config_paths(),
2072+
common_config_paths(mod.path),
20022073
{
20032074
"LlavaOnevisionForConditionalGeneration",
20042075
},

0 commit comments

Comments
 (0)