Skip to content

Commit 7c70701

Browse files
feat(mm): add sanity checks before probing paths
1 parent d81a3d6 commit 7c70701

File tree

2 files changed

+147
-36
lines changed

2 files changed

+147
-36
lines changed

invokeai/backend/model_manager/configs/factory.py

Lines changed: 127 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,29 @@
109109
logger = logging.getLogger(__name__)
110110
app_config = get_config()
111111

112+
# File suffixes (lower-case, including the leading dot) accepted as model weight files by the
# pre-probe sanity check. Frozen so this module-level lookup table cannot be mutated by accident.
_MODEL_EXTENSIONS = frozenset(
    {
        ".safetensors",
        ".ckpt",
        ".pt",
        ".pth",
        ".bin",
        ".gguf",
        ".onnx",
    }
)

# Config file names whose presence at the root of a directory marks it as a
# diffusers/transformers model. Frozen for the same reason as above.
_CONFIG_FILES = frozenset(
    {
        "model_index.json",
        "config.json",
    }
)

# A directory holding more files than this (counted recursively) is rejected as not being a model.
_MAX_FILES_IN_MODEL_DIR = 50

# How many directory levels below the given root are searched for model files.
_MAX_SEARCH_DEPTH = 2
134+
112135

113136
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
114137
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
@@ -276,6 +299,68 @@ def build_common_fields(
276299

277300
return fields
278301

302+
@staticmethod
303+
def _validate_path_looks_like_model(path: Path) -> None:
304+
"""Perform basic sanity checks to ensure a path looks like a model.
305+
306+
This prevents wasting time trying to identify obviously non-model paths like
307+
home directories or downloads folders. Raises RuntimeError if the path doesn't
308+
pass basic checks.
309+
310+
Args:
311+
path: The path to validate
312+
313+
Raises:
314+
RuntimeError: If the path doesn't look like a model
315+
"""
316+
if path.is_file():
317+
# For files, just check the extension
318+
if path.suffix.lower() not in _MODEL_EXTENSIONS:
319+
raise RuntimeError(
320+
f"File extension {path.suffix} is not a recognized model format. "
321+
f"Expected one of: {', '.join(sorted(_MODEL_EXTENSIONS))}"
322+
)
323+
else:
324+
# For directories, do a quick file count check with early exit
325+
total_files = 0
326+
for item in path.rglob("*"):
327+
if item.is_file():
328+
total_files += 1
329+
if total_files > _MAX_FILES_IN_MODEL_DIR:
330+
raise RuntimeError(
331+
f"Directory contains more than {_MAX_FILES_IN_MODEL_DIR} files. "
332+
"This looks like a general-purpose directory rather than a model. "
333+
"Please provide a path to a specific model file or model directory."
334+
)
335+
336+
# Check if it has config files at root (diffusers/transformers marker)
337+
has_root_config = any((path / config).exists() for config in _CONFIG_FILES)
338+
339+
if has_root_config:
340+
# Has a config file, looks like a valid model directory
341+
return
342+
343+
# Otherwise, search for model files within depth limit
344+
def find_model_files(current_path: Path, depth: int) -> bool:
345+
if depth > _MAX_SEARCH_DEPTH:
346+
return False
347+
try:
348+
for item in current_path.iterdir():
349+
if item.is_file() and item.suffix.lower() in _MODEL_EXTENSIONS:
350+
return True
351+
elif item.is_dir() and find_model_files(item, depth + 1):
352+
return True
353+
except PermissionError:
354+
pass
355+
return False
356+
357+
if not find_model_files(path, 0):
358+
raise RuntimeError(
359+
f"No model files or config files found in directory {path}. "
360+
f"Expected to find model files with extensions: {', '.join(sorted(_MODEL_EXTENSIONS))} "
361+
f"or config files: {', '.join(sorted(_CONFIG_FILES))}"
362+
)
363+
279364
@staticmethod
280365
def from_model_on_disk(
281366
mod: str | Path | ModelOnDisk,
@@ -290,6 +375,10 @@ def from_model_on_disk(
290375
if isinstance(mod, Path | str):
291376
mod = ModelOnDisk(Path(mod), hash_algo)
292377

378+
# Perform basic sanity checks before attempting any config matching
379+
# This rejects obviously non-model paths early, saving time
380+
ModelConfigFactory._validate_path_looks_like_model(mod.path)
381+
293382
# We will always need these fields to build any model config.
294383
fields = ModelConfigFactory.build_common_fields(mod, override_fields)
295384

@@ -317,48 +406,53 @@ def from_model_on_disk(
317406
logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
318407
except Exception as e:
319408
results[class_name] = e
320-
logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
409+
logger.debug(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
321410

322411
matches = [r for r in results.values() if isinstance(r, Config_Base)]
323412

324-
if not matches and app_config.allow_unknown_models:
325-
logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
326-
return Unknown_Config(
327-
**fields,
328-
# Override the type/format/base to ensure it's marked as unknown.
329-
base=BaseModelType.Unknown,
330-
type=ModelType.Unknown,
331-
format=ModelFormat.Unknown,
332-
)
413+
if not matches:
414+
# No matches at all, not even Unknown_Config (which declines when unknown models are disallowed). There is nothing to fall back to, so log the failure and raise.
415+
msg = f"No model config matched for model {mod.path}"
416+
logger.error(msg)
417+
raise RuntimeError(msg)
418+
419+
# It is possible that we have multiple matches. We need to prioritize them.
420+
#
421+
# Known cases where multiple matches can occur:
422+
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
423+
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
424+
# a config.json file. Prefer the main model.
425+
#
426+
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
427+
# sophisticated approach.
428+
#
429+
# Unknown models should always be the last resort fallback.
430+
def sort_key(m: AnyModelConfig) -> int:
431+
match m.type:
432+
case ModelType.Main:
433+
return 0
434+
case ModelType.LoRA:
435+
return 1
436+
case ModelType.CLIPEmbed:
437+
return 2
438+
case ModelType.Unknown:
439+
# Unknown should always be tried last as a fallback
440+
return 999
441+
case _:
442+
return 3
443+
444+
matches.sort(key=sort_key)
333445

334446
if len(matches) > 1:
335-
# We have multiple matches, in which case at most 1 is correct. We need to pick one.
336-
#
337-
# Known cases:
338-
# - SD main models can look like a LoRA when they have merged in LoRA weights. Prefer the main model.
339-
# - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
340-
# a config.json file. Prefer the main model.
341-
#
342-
# Given the above cases, we can prioritize the matches by type. If we find more cases, we may need a more
343-
# sophisticated approach.
344-
def sort_key(m: AnyModelConfig) -> int:
345-
match m.type:
346-
case ModelType.Main:
347-
return 0
348-
case ModelType.LoRA:
349-
return 1
350-
case ModelType.CLIPEmbed:
351-
return 2
352-
case _:
353-
return 3
354-
355-
matches.sort(key=sort_key)
356447
logger.warning(
357-
f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}."
448+
f"Multiple model config classes matched for model {mod.path}: {[type(m).__name__ for m in matches]}."
358449
)
359450

360451
instance = matches[0]
361-
logger.info(f"Model {mod.name} classified as {type(instance).__name__}")
452+
if isinstance(instance, Unknown_Config):
453+
logger.warning(f"Unable to identify model {mod.path}, falling back to Unknown_Config")
454+
else:
455+
logger.info(f"Model {mod.path} classified as {type(instance).__name__}")
362456

363457
# Now do any post-processing needed for specific model types/bases/etc.
364458
match instance.type:

invokeai/backend/model_manager/configs/unknown.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pydantic import Field
44

5+
from invokeai.app.services.config.config_default import get_config
56
from invokeai.backend.model_manager.configs.base import Config_Base
67
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
78
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
@@ -11,14 +12,30 @@
1112
ModelType,
1213
)
1314

15+
app_config = get_config()
16+
1417

1518
class Unknown_Config(Config_Base):
    """Model config for unknown models, used as a fallback when we cannot positively identify a model.

    The `base`, `type` and `format` fields only accept their respective `Unknown` values.
    """

    base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown)
    type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown)
    format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Create an Unknown_Config for a model that couldn't be positively identified.

        Note: Basic path validation (file extensions, directory structure) is already
        performed by ModelConfigFactory before this method is called.

        Args:
            mod: The model on disk. It is not inspected here.
            override_fields: Field values used to construct the config. Any `base`, `type`
                or `format` entries are replaced with the Unknown markers.

        Returns:
            A config whose base, type and format are all marked as unknown.

        Raises:
            NotAMatchError: If unknown models are not allowed by the app configuration.
        """
        if not app_config.allow_unknown_models:
            raise NotAMatchError("unknown models are not allowed by configuration")

        # Merge into a single dict so the Unknown markers win over any base/type/format already
        # present in override_fields. Passing `**override_fields` alongside explicit keywords
        # would raise TypeError (duplicate keyword argument) instead of overriding.
        unknown_markers: dict[str, Any] = {
            "base": BaseModelType.Unknown,
            "type": ModelType.Unknown,
            "format": ModelFormat.Unknown,
        }
        return cls(**{**override_fields, **unknown_markers})

0 commit comments

Comments
 (0)