Skip to content

Commit 881f063

Browse files
refactor(mm): simplify model classification process
Previously, we had a multi-phase strategy to identify models from their files on disk: 1. Run each model config classes' `matches()` method on the files. It checks if the model could possibly be an identified as the candidate model type. This was intended to be a quick check. Break on the first match. 2. If we have a match, run the config class's `parse()` method. It derive some additional model config attrs from the model files. This was intended to encapsulate heavier operations that may require loading the model into memory. 3. Derive the common model config attrs, like name, description, calculate the hash, etc. Some of these are also heavier operations. This strategy has some issues: - It is not clear how the pieces fit together. There is some back-and-forth between different methods and the config base class. It is hard to trace the flow of logic until you fully wrap your head around the system and therefore difficult to add a model architecture to the probe. - The assumption that we could do quick, lightweight checks before heavier checks is incorrect. We often _must_ load the model state dict in the `matches()` method. So there is no practical perf benefit to splitting up the responsibility of `matches()` and `parse()`. - Sometimes we need to do the same checks in `matches()` and `parse()`. In these cases, splitting the logic is has a negative perf impact because we are doing the same work twice. - As we introduce the concept of an "unknown" model config (i.e. a model that we cannot identify, but still record in the db; see #8582), we will _always_ run _all_ the checks for every model. Therefore we need not try to defer heavier checks or resource-intensive ops like hashing. We are going to do them anyways. - There are situations where a model may match multiple configs. One known case are SD pipeline models with merged LoRAs. In the old probe API, we relied on the implicit order of checks to know that if a model matched for pipeline _and_ LoRA, we prefer the pipeline match. But, in the new API, we do not have this implicit ordering of checks. To resolve this in a resilient way, we need to get all matches up front, then use tie-breaker logic to figure out which should win (or add "differential diagnosis" logic to the matchers). - Field overrides weren't handled well by this strategy. They were only applied at the very end, if a model matched successfully. This means we cannot tell the system "Hey, this model is type X with base Y. Trust me bro.". We cannot override the match logic. As we move towards letting users correct mis-identified models (see #8582), this is a requirement. We can simplify the process significantly and better support "unknown" models. Firstly, model config classes now have a single `from_model_on_disk()` method that attempts to construct an instance of the class from the model files. This replaces the `matches()` and `parse()` methods. If we fail to create the config instance, a special exception is raised that indicates why we think the files cannot be identified as the given model config class. Next, the flow for model identification is a bit simpler: - Derive all the common fields up-front (name, desc, hash, etc). - Merge in overrides. - Call `from_model_on_disk()` for every config class, passing in the fields. Overrides are handled in this method. - Record the results for each config class and choose the best one. The identification logic is a bit more verbose, with the special exceptions and handling of overrides, but it is very clear what is happening. The one downside I can think of for this strategy is we do need to check every model type, instead of stopping at the first match. It's a bit less efficient. In practice, however, this isn't a hot code path, and the improved clarity is worth far more than perf optimizations that the end user will likely never notice.
1 parent 1d3f6c4 commit 881f063

File tree

2 files changed

+663
-6
lines changed

2 files changed

+663
-6
lines changed

invokeai/app/services/model_install/model_install_default.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
612612
try:
613613
return ModelProbe.probe(model_path=model_path, fields=deepcopy(fields), hash_algo=hash_algo) # type: ignore
614614
except InvalidModelConfigException:
615-
return ModelConfigBase.classify(mod=model_path, hash_algo=hash_algo, **fields)
615+
return ModelConfigBase.classify(mod=model_path, fields=deepcopy(fields), hash_algo=hash_algo)
616616

617617
def _register(
618618
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None

0 commit comments

Comments
 (0)