Skip to content

Commit 7765c83

Browse files
feat(mm): wip port of main models to new api
1 parent 5f45a9c commit 7765c83

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,7 +1164,9 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
11641164
raise NotAMatch(cls, "unable to determine base type")
11651165

11661166
@classmethod
1167-
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType:
1167+
def _get_scheduler_prediction_type_or_raise(
1168+
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
1169+
) -> SchedulerPredictionType:
11681170
if base not in {
11691171
BaseModelType.StableDiffusion1,
11701172
BaseModelType.StableDiffusion2,
@@ -1186,7 +1188,7 @@ def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseMod
11861188
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
11871189

11881190
@classmethod
1189-
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType:
1191+
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType:
11901192
if base not in {
11911193
BaseModelType.StableDiffusion1,
11921194
BaseModelType.StableDiffusion2,
@@ -1197,20 +1199,29 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVa
11971199
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
11981200
in_channels = unet_config.get("in_channels")
11991201

1200-
match in_channels:
1201-
case 4:
1202-
return ModelVariantType.Normal
1203-
case 5:
1204-
if base is not BaseModelType.StableDiffusion2:
1205-
raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models")
1206-
return ModelVariantType.Depth
1207-
case 9:
1208-
return ModelVariantType.Inpaint
1209-
case _:
1210-
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}")
1202+
if base is BaseModelType.StableDiffusion2:
1203+
match in_channels:
1204+
case 4:
1205+
return ModelVariantType.Normal
1206+
case 9:
1207+
return ModelVariantType.Inpaint
1208+
case 5:
1209+
return ModelVariantType.Depth
1210+
case _:
1211+
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
1212+
else:
1213+
match in_channels:
1214+
case 4:
1215+
return ModelVariantType.Normal
1216+
case 9:
1217+
return ModelVariantType.Inpaint
1218+
case _:
1219+
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'")
12111220

12121221
@classmethod
1213-
def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]:
1222+
def _get_submodels_or_raise(
1223+
cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases
1224+
) -> dict[SubModelType, SubmodelDefinition]:
12141225
if base is not BaseModelType.StableDiffusion3:
12151226
raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")
12161227

0 commit comments

Comments
 (0)