Skip to content

Commit 3a44fde

Browse files
feat(mm): wip port of main models to new api
1 parent 7765c83 commit 3a44fde

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,12 +1126,29 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
11261126
)
11271127

11281128
base = fields.get("base") or cls._get_base_or_raise(mod)
1129+
if base in {
1130+
BaseModelType.StableDiffusion1,
1131+
BaseModelType.StableDiffusion2,
1132+
BaseModelType.StableDiffusionXL,
1133+
}:
1134+
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
1135+
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
1136+
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(base, prediction_type)
1137+
else:
1138+
variant= None
1139+
prediction_type = None
1140+
upcast_attention = False
11291141

1130-
return cls(**fields, base=base)
1142+
if base is BaseModelType.StableDiffusion3:
1143+
submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod, base)
1144+
else:
1145+
submodels = None
1146+
1147+
return cls(**fields, base=base, variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, submodels=submodels,)
11311148

11321149
@classmethod
11331150
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
1134-
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
1151+
# Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL).
11351152
unet_config_path = mod.path / "unet" / "config.json"
11361153
if unet_config_path.exists():
11371154
with open(unet_config_path) as file:
@@ -1172,7 +1189,7 @@ def _get_scheduler_prediction_type_or_raise(
11721189
BaseModelType.StableDiffusion2,
11731190
BaseModelType.StableDiffusionXL,
11741191
}:
1175-
raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'")
1192+
raise ValueError(f"Attempted to get scheduler prediction_type for non-UNet model base '{base}'")
11761193

11771194
scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
11781195

@@ -1185,7 +1202,7 @@ def _get_scheduler_prediction_type_or_raise(
11851202
case "epsilon":
11861203
return SchedulerPredictionType.Epsilon
11871204
case _:
1188-
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
1205+
raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}")
11891206

11901207
@classmethod
11911208
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType:
@@ -1266,6 +1283,21 @@ def _get_submodels_or_raise(
12661283
return submodels
12671284

12681285

1286+
@classmethod
1287+
def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType) -> bool:
1288+
if base not in {
1289+
BaseModelType.StableDiffusion1,
1290+
BaseModelType.StableDiffusion2,
1291+
BaseModelType.StableDiffusionXL,
1292+
}:
1293+
raise ValueError(f"Attempted to get upcast_attention flag for non-UNet model base '{base}'")
1294+
1295+
if base is BaseModelType.StableDiffusion2 and prediction_type is SchedulerPredictionType.VPrediction:
1296+
# SD2 v-prediction models need upcast_attention to be True
1297+
return True
1298+
1299+
return False
1300+
12691301
class IPAdapterConfigBase(ABC, BaseModel):
12701302
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
12711303

0 commit comments

Comments
 (0)