@@ -4,6 +4,7 @@
 from typing import Any, Dict, Literal, Optional, Union
 
 import safetensors.torch
+import spandrel
 import torch
 from picklescan.scanner import scan_file_path
 
@@ -242,15 +243,19 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C
             return ModelType.TextualInversion
 
         # Check if the model can be loaded as a SpandrelImageToImageModel.
+        # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
         try:
-            # TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
-            # _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
+            # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
+            # explored to avoid this:
+            # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
+            #    device. Unfortunately, some Spandrel models perform operations during initialization that are not
+            #    supported on meta tensors.
+            # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
+            #    This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
+            #    maintain it, and the risk of false positive detections is higher.
             _ = SpandrelImageToImageModel.load_from_file(model_path)
             return ModelType.SpandrelImageToImage
-        except Exception as e:
-            # TODO(ryand): Catch a more specific exception type here if we can.
-            # TODO(ryand): Delete this print statement.
-            print(e)
+        except spandrel.UnsupportedModelError:
             pass
 
         raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
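The first option discussed in the new comment block fails because meta tensors carry only shape and dtype metadata, not data. A minimal illustration of that limitation (not part of the PR; the tensor below is just a stand-in for a state_dict entry):

import torch

# A state_dict entry materialized on the meta device: shape/dtype only, no data.
w = torch.empty(8, 8, device="meta")

print(w.shape, w.dtype)  # metadata queries work, so shape-based checks are fine
y = w @ w                # shape propagation also works; y is still a meta tensor
print(y.device)          # meta

try:
    w.tolist()           # any op that needs real values fails on a meta tensor
except (RuntimeError, NotImplementedError) as err:
    print(f"data-dependent op failed: {err}")

This is why probing with `SpandrelImageToImageModel.load_from_state_dict(ckpt)` on a meta-device state_dict is unreliable: per the comment above, some Spandrel architectures perform exactly this kind of data-dependent operation during initialization.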
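For the check the diff actually lands on, here is a hedged sketch of the same idea using only spandrel's public API. The helper name is made up for illustration; the InvokeAI wrapper `SpandrelImageToImageModel.load_from_file` used in the diff is assumed to do an equivalent load internally, plus its own bookkeeping.

from pathlib import Path

import spandrel


def looks_like_spandrel_model(model_path: Path) -> bool:
    """Return True if spandrel recognizes the checkpoint as one of its architectures."""
    try:
        # Loads the checkpoint from disk and instantiates the matching architecture,
        # which is why the new comment notes that this check is expensive and runs last.
        spandrel.ModelLoader().load_from_file(model_path)
        return True
    except spandrel.UnsupportedModelError:
        # Raised when no architecture known to spandrel matches the state_dict.
        return False

Catching `spandrel.UnsupportedModelError` rather than a bare `Exception` (as the removed code did) also means genuine load failures, such as a corrupt file, propagate instead of being silently treated as "not a Spandrel model".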