Skip to content

Commit a129424

Browse files
feat(mm): add model config schema migration logic
1 parent bdf3474 commit a129424

File tree

1 file changed

+69
-24
lines changed

1 file changed

+69
-24
lines changed

invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
from invokeai.app.services.config import InvokeAIAppConfig
1010
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
11-
from invokeai.backend.model_manager.configs.factory import AnyModelConfigValidator
11+
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, AnyModelConfigValidator
12+
from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelType, SchedulerPredictionType
1213

1314

1415
class NormalizeResult(NamedTuple):
@@ -29,9 +30,8 @@ def __call__(self, cursor: sqlite3.Cursor) -> None:
2930

3031
for model_id, config_json in rows:
3132
try:
32-
migrated_config_dict = self._migrate_config(config_json)
33-
# Get the model config as a pydantic object
34-
config = AnyModelConfigValidator.validate_python(migrated_config_dict)
33+
# Migrate the config JSON to the latest schema
34+
config = self._parse_and_migrate_config(config_json)
3535
except ValidationError:
3636
# This could happen if the config schema changed in a way that makes old configs invalid. Unlikely
3737
# for users, more likely for devs testing out migration paths.
@@ -71,31 +71,76 @@ def __call__(self, cursor: sqlite3.Cursor) -> None:
7171
cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
7272
cursor.execute("RELEASE SAVEPOINT migrate_model")
7373
self._rollback_file_ops(rollback_ops)
74-
continue
74+
raise
7575

7676
cursor.execute("RELEASE SAVEPOINT migrate_model")
7777

7878
self._prune_empty_directories()
7979

80-
def _migrate_config(self, config_json: Any) -> str | None:
81-
config_dict = json.loads(config_json)
82-
83-
# TODO: migrate fields, review changes to ensure we hit all cases for v6.7.0 to v6.8.0 upgrade.
84-
85-
# Prior to v6.8.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX
86-
# variants.
87-
#
88-
# `config_path` was set to one of:
89-
# - flux-dev
90-
# - flux-dev-fill
91-
# - flux-schnell
92-
#
93-
# `variant` was set to ModelVariantType.Inpaint for FLUX Fill models and ModelVariantType.Normal for all other FLUX
94-
# models.
95-
#
96-
# We now use the `variant` field to directly represent the FLUX variant type, and `config_path` is no longer used.
97-
98-
return config_dict
80+
def _parse_and_migrate_config(self, config_json: Any) -> AnyModelConfig:
81+
config_dict: dict[str, Any] = json.loads(config_json)
82+
83+
# In v6.8.0 we made some improvements to the model taxonomy and the model config schemas. There are a changes
84+
# we need to make to old configs to bring them up to date.
85+
86+
base = config_dict.get("base")
87+
type = config_dict.get("type")
88+
if base == BaseModelType.Flux.value and type == ModelType.Main.value:
89+
# Prior to v6.8.0, we used an awkward combination of `config_path` and `variant` to distinguish between FLUX
90+
# variants.
91+
#
92+
# `config_path` was set to one of:
93+
# - flux-dev
94+
# - flux-dev-fill
95+
# - flux-schnell
96+
#
97+
# `variant` was set to ModelVariantType.Inpaint for FLUX Fill models and ModelVariantType.Normal for all other FLUX
98+
# models.
99+
#
100+
# We now use the `variant` field to directly represent the FLUX variant type, and `config_path` is no longer used.
101+
102+
# Extract and remove `config_path` if present.
103+
config_path = config_dict.pop("config_path", None)
104+
105+
match config_path:
106+
case "flux-dev":
107+
config_dict["variant"] = FluxVariantType.Dev.value
108+
case "flux-dev-fill":
109+
config_dict["variant"] = FluxVariantType.DevFill.value
110+
case "flux-schnell":
111+
config_dict["variant"] = FluxVariantType.Schnell.value
112+
case _:
113+
# Unknown config_path - default to Dev variant
114+
config_dict["variant"] = FluxVariantType.Dev.value
115+
116+
if (
117+
base
118+
in {
119+
BaseModelType.StableDiffusion1.value,
120+
BaseModelType.StableDiffusion2.value,
121+
BaseModelType.StableDiffusionXL.value,
122+
BaseModelType.StableDiffusionXLRefiner.value,
123+
}
124+
and type == "main"
125+
):
126+
# Prior to v6.8.0, the prediction_type field was optional and would default to Epsilon if not present.
127+
# We now make it explicit and always present. Use the existing value if present, otherwise default to
128+
# Epsilon, matching the probe logic.
129+
#
130+
# It's only on SD1.x, SD2.x, and SDXL main models.
131+
config_dict["prediction_type"] = config_dict.get("prediction_type", SchedulerPredictionType.Epsilon.value)
132+
133+
if type == ModelType.CLIPVision.value:
134+
# Prior to v6.8.0, some CLIP Vision models were associated with a specific base model architecture:
135+
# - CLIP-ViT-bigG-14-laion2B-39B-b160k is the image encoder for SDXL IP Adapter and was associated with SDXL
136+
# - CLIP-ViT-H-14-laion2B-s32B-b79K is the image encoder for SD1.5 IP Adapter and was associated with SD1.5
137+
#
138+
# While this made some sense at the time, it is more correct and flexible to treat CLIP Vision models
139+
# as independent of any specific base model architecture.
140+
config_dict["base"] = BaseModelType.Any.value
141+
142+
migrated_config = AnyModelConfigValidator.validate_python(config_dict)
143+
return migrated_config
99144

100145
def _normalize_model_storage(self, key: str, path_value: str) -> NormalizeResult:
101146
models_dir = self._models_dir

0 commit comments

Comments
 (0)