@@ -1094,7 +1094,6 @@ class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbe
1094
1094
BaseModelType .StableDiffusion3 ,
1095
1095
BaseModelType .StableDiffusionXL ,
1096
1096
BaseModelType .StableDiffusionXLRefiner ,
1097
- BaseModelType .Flux ,
1098
1097
BaseModelType .CogView4 ,
1099
1098
]
1100
1099
@@ -1104,6 +1103,157 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
1104
1103
1105
1104
base : MainDiffusers_SupportedBases = Field ()
1106
1105
1106
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
    """Build a config for a diffusers-format main model found on disk.

    Args:
        mod: The on-disk model being probed. ``mod.path`` is used as the model directory.
        fields: Caller-supplied override fields, forwarded to the constructor. A truthy
            ``"base"`` entry takes precedence over the base detected from disk.

    Returns:
        A new ``cls`` instance built from ``fields`` plus the resolved ``base``.

    Raises:
        NotAMatch: If the base cannot be determined from the files on disk. The validation
            helpers called first presumably raise as well when their check fails
            (NOTE(review): confirm against their definitions).
    """
    _validate_is_dir(cls, mod)

    _validate_override_fields(cls, fields)

    _validate_class_name(
        cls,
        mod.path / "config.json",
        {
            "StableDiffusionPipeline",
            "StableDiffusionInpaintPipeline",
            "StableDiffusionXLPipeline",
            "StableDiffusionXLImg2ImgPipeline",
            "StableDiffusionXLInpaintPipeline",
            "StableDiffusion3Pipeline",
            "LatentConsistencyModelPipeline",
            "SD3Transformer2DModel",
            "CogView4Pipeline",
        },
    )

    # Only probe the disk when the caller did not provide a usable base.
    base = fields.get("base") or cls._get_base_or_raise(mod)

    # Merge into a single mapping instead of `cls(**fields, base=base)`: passing `base=` next to
    # `**fields` raises "TypeError: got multiple values for keyword argument 'base'" whenever
    # the caller already supplied "base" in `fields`.
    return cls(**{**fields, "base": base})
1131
+
1132
+ @classmethod
1133
+ def _get_base_or_raise (cls , mod : ModelOnDisk ) -> MainDiffusers_SupportedBases :
1134
+ # Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
1135
+ unet_config_path = mod .path / "unet" / "config.json"
1136
+ if unet_config_path .exists ():
1137
+ with open (unet_config_path ) as file :
1138
+ unet_conf = json .load (file )
1139
+ cross_attention_dim = unet_conf .get ("cross_attention_dim" )
1140
+ match cross_attention_dim :
1141
+ case 768 :
1142
+ return BaseModelType .StableDiffusion1
1143
+ case 1024 :
1144
+ return BaseModelType .StableDiffusion2
1145
+ case 1280 :
1146
+ return BaseModelType .StableDiffusionXLRefiner
1147
+ case 2048 :
1148
+ return BaseModelType .StableDiffusionXL
1149
+ case _:
1150
+ raise NotAMatch (cls , f"unrecognized cross_attention_dim { cross_attention_dim } " )
1151
+
1152
+ # Handle pipelines with a transformer (i.e. SD3).
1153
+ transformer_config_path = mod .path / "transformer" / "config.json"
1154
+ if transformer_config_path .exists ():
1155
+ class_name = _get_class_name_from_config (cls , transformer_config_path )
1156
+ match class_name :
1157
+ case "SD3Transformer2DModel" :
1158
+ return BaseModelType .StableDiffusion3
1159
+ case "CogView4Transformer2DModel" :
1160
+ return BaseModelType .CogView4
1161
+ case _:
1162
+ raise NotAMatch (cls , f"unrecognized transformer class name { class_name } " )
1163
+
1164
+ raise NotAMatch (cls , "unable to determine base type" )
1165
+
1166
@classmethod
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType:
    """Read the scheduler's prediction type for a UNet-based pipeline.

    Args:
        mod: The on-disk model being probed. ``mod.path`` is used as the pipeline directory.
        base: The already-resolved base; must be SD1, SD2 or SDXL.

    Returns:
        The prediction type named in ``scheduler/scheduler_config.json``, or epsilon when the
        config has no ``prediction_type`` key.

    Raises:
        ValueError: If ``base`` is not one of the accepted bases.
        NotAMatch: If the configured prediction type is neither ``v_prediction`` nor ``epsilon``.
    """
    unet_bases = {
        BaseModelType.StableDiffusion1,
        BaseModelType.StableDiffusion2,
        BaseModelType.StableDiffusionXL,
    }
    if base not in unet_bases:
        raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'")

    scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")

    # TODO(psyche): Is epsilon the right default or should we raise if it's not present?
    prediction_type = scheduler_conf.get("prediction_type", "epsilon")

    if prediction_type == "v_prediction":
        return SchedulerPredictionType.VPrediction
    if prediction_type == "epsilon":
        return SchedulerPredictionType.Epsilon
    raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
1187
+
1188
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType:
    """Derive the model variant from the UNet's ``in_channels``.

    Args:
        mod: The on-disk model being probed. ``mod.path`` is used as the pipeline directory.
        base: The already-resolved base; must be SD1, SD2 or SDXL.

    Returns:
        Normal for 4 input channels, Depth for 5 (SD2 only), Inpaint for 9.

    Raises:
        ValueError: If ``base`` is not one of the accepted bases.
        NotAMatch: If ``in_channels`` is unrecognized, or is 5 on a non-SD2 base.
    """
    bases_with_variants = {
        BaseModelType.StableDiffusion1,
        BaseModelType.StableDiffusion2,
        BaseModelType.StableDiffusionXL,
    }
    if base not in bases_with_variants:
        raise ValueError(f"Attempted to get variant for model base '{base}' but it does not have variants")

    unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
    in_channels = unet_config.get("in_channels")

    if in_channels == 4:
        return ModelVariantType.Normal

    if in_channels == 5:
        # Only SD2 is accepted with a 5-channel UNet.
        if base is BaseModelType.StableDiffusion2:
            return ModelVariantType.Depth
        raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models")

    if in_channels == 9:
        return ModelVariantType.Inpaint

    raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}")
1211
+
1212
@classmethod
def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]:
    """Collect the submodel definitions of an SD3 pipeline from its ``model_index.json``.

    Only entries whose class is ``CLIPTextModelWithProjection`` or ``SD3Transformer2DModel``
    are recorded; every other entry is ignored.

    Args:
        mod: The on-disk model being probed. ``mod.path`` is used as the pipeline directory.
        base: The already-resolved base; must be Stable Diffusion 3.

    Returns:
        A mapping from submodel type to its definition (path, model type, variant).

    Raises:
        ValueError: If ``base`` is not Stable Diffusion 3.
    """
    if base is not BaseModelType.StableDiffusion3:
        raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")

    # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
    model_index = _get_config_or_raise(cls, mod.path / "model_index.json")

    submodels: dict[SubModelType, SubmodelDefinition] = {}

    for name, entry in model_index.items():
        # Keys with a leading underscore are top-level metadata, not submodels.
        if name.startswith("_"):
            continue
        # A submodel entry is a two-item list such as ["diffusers", "SD3Transformer2DModel"].
        if not isinstance(entry, list) or len(entry) != 2:
            continue

        # The key (e.g. "transformer") is also the name of the directory holding the submodel.
        class_name = entry[1]

        if class_name == "CLIPTextModelWithProjection":
            model_type = ModelType.CLIPEmbed
            path_or_prefix = (mod.path / name).resolve().as_posix()
            # The CLIP variant is read from the submodel's own config.
            clip_embed_config = _get_config_or_raise(cls, mod.path / name / "config.json")
            variant = _get_clip_variant_type_from_config(clip_embed_config)
        elif class_name == "SD3Transformer2DModel":
            model_type = ModelType.Main
            path_or_prefix = (mod.path / name).resolve().as_posix()
            variant = None
        else:
            continue

        submodels[SubModelType(name)] = SubmodelDefinition(
            path_or_prefix=path_or_prefix,
            model_type=model_type,
            variant=variant,
        )

    return submodels
1256
+
1107
1257
1108
1258
class IPAdapterConfigBase (ABC , BaseModel ):
1109
1259
type : Literal [ModelType .IPAdapter ] = Field (default = ModelType .IPAdapter )
@@ -1231,27 +1381,27 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> IPAdapterCheckpoint_SupportedBa
1231
1381
raise NotAMatch (cls , f"unrecognized cross attention dimension { cross_attention_dim } " )
1232
1382
1233
1383
1384
+ def _get_clip_variant_type_from_config (config : dict [str , Any ]) -> ClipVariantType | None :
1385
+ try :
1386
+ hidden_size = config .get ("hidden_size" )
1387
+ match hidden_size :
1388
+ case 1280 :
1389
+ return ClipVariantType .G
1390
+ case 768 :
1391
+ return ClipVariantType .L
1392
+ case _:
1393
+ return None
1394
+ except Exception :
1395
+ return None
1396
+
1397
+
1234
1398
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
    """Model config for Clip Embeddings."""

    # Each field is pinned to a single literal value and defaults to that same value.
    # CLIP embed configs use the `Any` base rather than a specific model base.
    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    # The model type is fixed to CLIPEmbed.
    type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
    # The model format is fixed to Diffusers.
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
1240
1404
1241
- @classmethod
1242
- def _get_clip_variant_type (cls , config : dict [str , Any ]) -> ClipVariantType | None :
1243
- try :
1244
- hidden_size = config .get ("hidden_size" )
1245
- match hidden_size :
1246
- case 1280 :
1247
- return ClipVariantType .G
1248
- case 768 :
1249
- return ClipVariantType .L
1250
- case _:
1251
- return None
1252
- except Exception :
1253
- return None
1254
-
1255
1405
1256
1406
class CLIPGEmbedDiffusersConfig (CLIPEmbedDiffusersConfig , ModelConfigBase ):
1257
1407
"""Model config for CLIP-G Embeddings."""
@@ -1269,7 +1419,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1269
1419
_validate_override_fields (cls , fields )
1270
1420
1271
1421
_validate_class_name (
1272
- cls , mod .path / "config.json" , {"CLIPModel" , "CLIPTextModel" , "CLIPTextModelWithProjection" }
1422
+ cls ,
1423
+ mod .path / "config.json" ,
1424
+ {
1425
+ "CLIPModel" ,
1426
+ "CLIPTextModel" ,
1427
+ "CLIPTextModelWithProjection" ,
1428
+ },
1273
1429
)
1274
1430
1275
1431
cls ._validate_clip_g_variant (mod )
@@ -1279,7 +1435,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1279
1435
@classmethod
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
    """Reject the model unless its ``config.json`` identifies it as the CLIP-G variant.

    Args:
        mod: The on-disk model being probed. ``mod.path`` is used as the model directory.

    Raises:
        NotAMatch: If the variant derived from the config is not ``ClipVariantType.G``.
    """
    config_path = mod.path / "config.json"
    detected_variant = _get_clip_variant_type_from_config(_get_config_or_raise(cls, config_path))

    if detected_variant is not ClipVariantType.G:
        raise NotAMatch(cls, "model does not match CLIP-G heuristics")
@@ -1301,7 +1457,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1301
1457
_validate_override_fields (cls , fields )
1302
1458
1303
1459
_validate_class_name (
1304
- cls , mod .path / "config.json" , {"CLIPModel" , "CLIPTextModel" , "CLIPTextModelWithProjection" }
1460
+ cls ,
1461
+ mod .path / "config.json" ,
1462
+ {
1463
+ "CLIPModel" ,
1464
+ "CLIPTextModel" ,
1465
+ "CLIPTextModelWithProjection" ,
1466
+ },
1305
1467
)
1306
1468
1307
1469
cls ._validate_clip_l_variant (mod )
@@ -1311,7 +1473,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1311
1473
@classmethod
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
    """Reject the model unless its ``config.json`` identifies it as the CLIP-L variant.

    Args:
        mod: The on-disk model being probed. ``mod.path`` is used as the model directory.

    Raises:
        NotAMatch: If the variant derived from the config is not ``ClipVariantType.L``.
    """
    config = _get_config_or_raise(cls, mod.path / "config.json")
    clip_variant = _get_clip_variant_type_from_config(config)

    if clip_variant is not ClipVariantType.L:
        # The message previously said "CLIP-G" (copied from the CLIP-G validator), which
        # misreported which check had failed.
        raise NotAMatch(cls, "model does not match CLIP-L heuristics")
@@ -1330,7 +1492,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1330
1492
1331
1493
_validate_override_fields (cls , fields )
1332
1494
1333
- _validate_class_name (cls , mod .path / "config.json" , {"CLIPVisionModelWithProjection" })
1495
+ _validate_class_name (
1496
+ cls ,
1497
+ mod .path / "config.json" ,
1498
+ {
1499
+ "CLIPVisionModelWithProjection" ,
1500
+ },
1501
+ )
1334
1502
1335
1503
return cls (** fields )
1336
1504
@@ -1354,7 +1522,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1354
1522
1355
1523
_validate_override_fields (cls , fields )
1356
1524
1357
- _validate_class_name (cls , mod .path / "config.json" , {"T2IAdapter" })
1525
+ _validate_class_name (
1526
+ cls ,
1527
+ mod .path / "config.json" ,
1528
+ {
1529
+ "T2IAdapter" ,
1530
+ },
1531
+ )
1358
1532
1359
1533
base = fields .get ("base" ) or cls ._get_base_or_raise (mod )
1360
1534
@@ -1421,7 +1595,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1421
1595
1422
1596
_validate_override_fields (cls , fields )
1423
1597
1424
- _validate_class_name (cls , mod .path / "config.json" , {"SiglipModel" })
1598
+ _validate_class_name (
1599
+ cls ,
1600
+ mod .path / "config.json" ,
1601
+ {
1602
+ "SiglipModel" ,
1603
+ },
1604
+ )
1425
1605
1426
1606
return cls (** fields )
1427
1607
@@ -1458,7 +1638,13 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1458
1638
1459
1639
_validate_override_fields (cls , fields )
1460
1640
1461
- _validate_class_name (cls , mod .path / "config.json" , {"LlavaOnevisionForConditionalGeneration" })
1641
+ _validate_class_name (
1642
+ cls ,
1643
+ mod .path / "config.json" ,
1644
+ {
1645
+ "LlavaOnevisionForConditionalGeneration" ,
1646
+ },
1647
+ )
1462
1648
1463
1649
return cls (** fields )
1464
1650
0 commit comments