Skip to content

Commit 5f45a9c

Browse files
feat(mm): wip port of main models to new api
1 parent 1268b23 commit 5f45a9c

File tree

1 file changed

+209
-23
lines changed

1 file changed

+209
-23
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 209 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,6 @@ class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbe
10941094
BaseModelType.StableDiffusion3,
10951095
BaseModelType.StableDiffusionXL,
10961096
BaseModelType.StableDiffusionXLRefiner,
1097-
BaseModelType.Flux,
10981097
BaseModelType.CogView4,
10991098
]
11001099

@@ -1104,6 +1103,157 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
11041103

11051104
base: MainDiffusers_SupportedBases = Field()
11061105

1106+
@classmethod
1107+
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1108+
_validate_is_dir(cls, mod)
1109+
1110+
_validate_override_fields(cls, fields)
1111+
1112+
_validate_class_name(
1113+
cls,
1114+
mod.path / "config.json",
1115+
{
1116+
"StableDiffusionPipeline",
1117+
"StableDiffusionInpaintPipeline",
1118+
"StableDiffusionXLPipeline",
1119+
"StableDiffusionXLImg2ImgPipeline",
1120+
"StableDiffusionXLInpaintPipeline",
1121+
"StableDiffusion3Pipeline",
1122+
"LatentConsistencyModelPipeline",
1123+
"SD3Transformer2DModel",
1124+
"CogView4Pipeline",
1125+
},
1126+
)
1127+
1128+
base = fields.get("base") or cls._get_base_or_raise(mod)
1129+
1130+
return cls(**fields, base=base)
1131+
1132+
@classmethod
1133+
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
1134+
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
1135+
unet_config_path = mod.path / "unet" / "config.json"
1136+
if unet_config_path.exists():
1137+
with open(unet_config_path) as file:
1138+
unet_conf = json.load(file)
1139+
cross_attention_dim = unet_conf.get("cross_attention_dim")
1140+
match cross_attention_dim:
1141+
case 768:
1142+
return BaseModelType.StableDiffusion1
1143+
case 1024:
1144+
return BaseModelType.StableDiffusion2
1145+
case 1280:
1146+
return BaseModelType.StableDiffusionXLRefiner
1147+
case 2048:
1148+
return BaseModelType.StableDiffusionXL
1149+
case _:
1150+
raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}")
1151+
1152+
# Handle pipelines with a transformer (i.e. SD3).
1153+
transformer_config_path = mod.path / "transformer" / "config.json"
1154+
if transformer_config_path.exists():
1155+
class_name = _get_class_name_from_config(cls, transformer_config_path)
1156+
match class_name:
1157+
case "SD3Transformer2DModel":
1158+
return BaseModelType.StableDiffusion3
1159+
case "CogView4Transformer2DModel":
1160+
return BaseModelType.CogView4
1161+
case _:
1162+
raise NotAMatch(cls, f"unrecognized transformer class name {class_name}")
1163+
1164+
raise NotAMatch(cls, "unable to determine base type")
1165+
1166+
@classmethod
1167+
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType:
1168+
if base not in {
1169+
BaseModelType.StableDiffusion1,
1170+
BaseModelType.StableDiffusion2,
1171+
BaseModelType.StableDiffusionXL,
1172+
}:
1173+
raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'")
1174+
1175+
scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
1176+
1177+
# TODO(psyche): Is epsilon the right default or should we raise if it's not present?
1178+
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
1179+
1180+
match prediction_type:
1181+
case "v_prediction":
1182+
return SchedulerPredictionType.VPrediction
1183+
case "epsilon":
1184+
return SchedulerPredictionType.Epsilon
1185+
case _:
1186+
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
1187+
1188+
@classmethod
1189+
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType:
1190+
if base not in {
1191+
BaseModelType.StableDiffusion1,
1192+
BaseModelType.StableDiffusion2,
1193+
BaseModelType.StableDiffusionXL,
1194+
}:
1195+
raise ValueError(f"Attempted to get variant for model base '{base}' but it does not have variants")
1196+
1197+
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
1198+
in_channels = unet_config.get("in_channels")
1199+
1200+
match in_channels:
1201+
case 4:
1202+
return ModelVariantType.Normal
1203+
case 5:
1204+
if base is not BaseModelType.StableDiffusion2:
1205+
raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models")
1206+
return ModelVariantType.Depth
1207+
case 9:
1208+
return ModelVariantType.Inpaint
1209+
case _:
1210+
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}")
1211+
1212+
@classmethod
1213+
def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]:
1214+
if base is not BaseModelType.StableDiffusion3:
1215+
raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")
1216+
1217+
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
1218+
config = _get_config_or_raise(cls, mod.path / "model_index.json")
1219+
1220+
submodels: dict[SubModelType, SubmodelDefinition] = {}
1221+
1222+
for key, value in config.items():
1223+
# Anything that starts with an underscore is top-level metadata, not a submodel
1224+
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
1225+
continue
1226+
# The key is something like "transformer" and is a submodel - it will be in a dir of the same name.
1227+
# The value value is something like ["diffusers", "SD3Transformer2DModel"]
1228+
_library_name, class_name = value
1229+
1230+
match class_name:
1231+
case "CLIPTextModelWithProjection":
1232+
model_type = ModelType.CLIPEmbed
1233+
path_or_prefix = (mod.path / key).resolve().as_posix()
1234+
1235+
# We need to read the config to determine the variant of the CLIP model.
1236+
clip_embed_config = _get_config_or_raise(cls, mod.path / key / "config.json")
1237+
variant = _get_clip_variant_type_from_config(clip_embed_config)
1238+
submodels[SubModelType(key)] = SubmodelDefinition(
1239+
path_or_prefix=path_or_prefix,
1240+
model_type=model_type,
1241+
variant=variant,
1242+
)
1243+
case "SD3Transformer2DModel":
1244+
model_type = ModelType.Main
1245+
path_or_prefix = (mod.path / key).resolve().as_posix()
1246+
variant = None
1247+
submodels[SubModelType(key)] = SubmodelDefinition(
1248+
path_or_prefix=path_or_prefix,
1249+
model_type=model_type,
1250+
variant=variant,
1251+
)
1252+
case _:
1253+
pass
1254+
1255+
return submodels
1256+
11071257

11081258
class IPAdapterConfigBase(ABC, BaseModel):
11091259
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
@@ -1231,27 +1381,27 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpoint_SupportedBa
12311381
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
12321382

12331383

1384+
def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None:
1385+
try:
1386+
hidden_size = config.get("hidden_size")
1387+
match hidden_size:
1388+
case 1280:
1389+
return ClipVariantType.G
1390+
case 768:
1391+
return ClipVariantType.L
1392+
case _:
1393+
return None
1394+
except Exception:
1395+
return None
1396+
1397+
12341398
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
12351399
"""Model config for Clip Embeddings."""
12361400

12371401
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
12381402
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
12391403
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
12401404

1241-
@classmethod
1242-
def _get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
1243-
try:
1244-
hidden_size = config.get("hidden_size")
1245-
match hidden_size:
1246-
case 1280:
1247-
return ClipVariantType.G
1248-
case 768:
1249-
return ClipVariantType.L
1250-
case _:
1251-
return None
1252-
except Exception:
1253-
return None
1254-
12551405

12561406
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
12571407
"""Model config for CLIP-G Embeddings."""
@@ -1269,7 +1419,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
12691419
_validate_override_fields(cls, fields)
12701420

12711421
_validate_class_name(
1272-
cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
1422+
cls,
1423+
mod.path / "config.json",
1424+
{
1425+
"CLIPModel",
1426+
"CLIPTextModel",
1427+
"CLIPTextModelWithProjection",
1428+
},
12731429
)
12741430

12751431
cls._validate_clip_g_variant(mod)
@@ -1279,7 +1435,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
12791435
@classmethod
12801436
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
12811437
config = _get_config_or_raise(cls, mod.path / "config.json")
1282-
clip_variant = cls._get_clip_variant_type(config)
1438+
clip_variant = _get_clip_variant_type_from_config(config)
12831439

12841440
if clip_variant is not ClipVariantType.G:
12851441
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
@@ -1301,7 +1457,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
13011457
_validate_override_fields(cls, fields)
13021458

13031459
_validate_class_name(
1304-
cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
1460+
cls,
1461+
mod.path / "config.json",
1462+
{
1463+
"CLIPModel",
1464+
"CLIPTextModel",
1465+
"CLIPTextModelWithProjection",
1466+
},
13051467
)
13061468

13071469
cls._validate_clip_l_variant(mod)
@@ -1311,7 +1473,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
13111473
@classmethod
13121474
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
13131475
config = _get_config_or_raise(cls, mod.path / "config.json")
1314-
clip_variant = cls._get_clip_variant_type(config)
1476+
clip_variant = _get_clip_variant_type_from_config(config)
13151477

13161478
if clip_variant is not ClipVariantType.L:
13171479
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
@@ -1330,7 +1492,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
13301492

13311493
_validate_override_fields(cls, fields)
13321494

1333-
_validate_class_name(cls, mod.path / "config.json", {"CLIPVisionModelWithProjection"})
1495+
_validate_class_name(
1496+
cls,
1497+
mod.path / "config.json",
1498+
{
1499+
"CLIPVisionModelWithProjection",
1500+
},
1501+
)
13341502

13351503
return cls(**fields)
13361504

@@ -1354,7 +1522,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
13541522

13551523
_validate_override_fields(cls, fields)
13561524

1357-
_validate_class_name(cls, mod.path / "config.json", {"T2IAdapter"})
1525+
_validate_class_name(
1526+
cls,
1527+
mod.path / "config.json",
1528+
{
1529+
"T2IAdapter",
1530+
},
1531+
)
13581532

13591533
base = fields.get("base") or cls._get_base_or_raise(mod)
13601534

@@ -1421,7 +1595,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
14211595

14221596
_validate_override_fields(cls, fields)
14231597

1424-
_validate_class_name(cls, mod.path / "config.json", {"SiglipModel"})
1598+
_validate_class_name(
1599+
cls,
1600+
mod.path / "config.json",
1601+
{
1602+
"SiglipModel",
1603+
},
1604+
)
14251605

14261606
return cls(**fields)
14271607

@@ -1458,7 +1638,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
14581638

14591639
_validate_override_fields(cls, fields)
14601640

1461-
_validate_class_name(cls, mod.path / "config.json", {"LlavaOnevisionForConditionalGeneration"})
1641+
_validate_class_name(
1642+
cls,
1643+
mod.path / "config.json",
1644+
{
1645+
"LlavaOnevisionForConditionalGeneration",
1646+
},
1647+
)
14621648

14631649
return cls(**fields)
14641650

0 commit comments

Comments
 (0)