Skip to content

Commit 1ab20f4

Browse files
committed
Tidy spandrel model probe logic, and document the reasons behind the current implementation.
1 parent 9328c17 commit 1ab20f4

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

invokeai/backend/model_manager/probe.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Dict, Literal, Optional, Union
55

66
import safetensors.torch
7+
import spandrel
78
import torch
89
from picklescan.scanner import scan_file_path
910

@@ -242,15 +243,19 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
242243
return ModelType.TextualInversion
243244

244245
# Check if the model can be loaded as a SpandrelImageToImageModel.
246+
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
245247
try:
246-
# TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
247-
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
248+
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
249+
# explored to avoid this:
250+
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
251+
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
252+
# supported on meta tensors.
253+
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
254+
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
255+
# maintain it, and the risk of false positive detections is higher.
248256
_ = SpandrelImageToImageModel.load_from_file(model_path)
249257
return ModelType.SpandrelImageToImage
250-
except Exception as e:
251-
# TODO(ryand): Catch a more specific exception type here if we can.
252-
# TODO(ryand): Delete this print statement.
253-
print(e)
258+
except spandrel.UnsupportedModelError:
254259
pass
255260

256261
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")

0 commit comments

Comments
 (0)