@@ -1164,7 +1164,9 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
1164
1164
raise NotAMatch (cls , "unable to determine base type" )
1165
1165
1166
1166
@classmethod
1167
- def _get_scheduler_prediction_type_or_raise (cls , mod : ModelOnDisk , base : BaseModelType ) -> SchedulerPredictionType :
1167
+ def _get_scheduler_prediction_type_or_raise (
1168
+ cls , mod : ModelOnDisk , base : MainDiffusers_SupportedBases
1169
+ ) -> SchedulerPredictionType :
1168
1170
if base not in {
1169
1171
BaseModelType .StableDiffusion1 ,
1170
1172
BaseModelType .StableDiffusion2 ,
@@ -1186,7 +1188,7 @@ def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseMod
1186
1188
raise NotAMatch (cls , f"unrecognized scheduler prediction type { prediction_type } " )
1187
1189
1188
1190
@classmethod
1189
- def _get_variant_or_raise (cls , mod : ModelOnDisk , base : BaseModelType ) -> ModelVariantType :
1191
+ def _get_variant_or_raise (cls , mod : ModelOnDisk , base : MainDiffusers_SupportedBases ) -> ModelVariantType :
1190
1192
if base not in {
1191
1193
BaseModelType .StableDiffusion1 ,
1192
1194
BaseModelType .StableDiffusion2 ,
@@ -1197,20 +1199,29 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVa
1197
1199
unet_config = _get_config_or_raise (cls , mod .path / "unet" / "config.json" )
1198
1200
in_channels = unet_config .get ("in_channels" )
1199
1201
1200
- match in_channels :
1201
- case 4 :
1202
- return ModelVariantType .Normal
1203
- case 5 :
1204
- if base is not BaseModelType .StableDiffusion2 :
1205
- raise NotAMatch (cls , "in_channels=5 is only valid for Stable Diffusion 2 models" )
1206
- return ModelVariantType .Depth
1207
- case 9 :
1208
- return ModelVariantType .Inpaint
1209
- case _:
1210
- raise NotAMatch (cls , f"unrecognized unet in_channels { in_channels } " )
1202
+ if base is BaseModelType .StableDiffusion2 :
1203
+ match in_channels :
1204
+ case 4 :
1205
+ return ModelVariantType .Normal
1206
+ case 9 :
1207
+ return ModelVariantType .Inpaint
1208
+ case 5 :
1209
+ return ModelVariantType .Depth
1210
+ case _:
1211
+ raise NotAMatch (cls , f"unrecognized unet in_channels { in_channels } for base '{ base } '" )
1212
+ else :
1213
+ match in_channels :
1214
+ case 4 :
1215
+ return ModelVariantType .Normal
1216
+ case 9 :
1217
+ return ModelVariantType .Inpaint
1218
+ case _:
1219
+ raise NotAMatch (cls , f"unrecognized unet in_channels { in_channels } for base '{ base } '" )
1211
1220
1212
1221
@classmethod
1213
- def _get_submodels_or_raise (cls , mod : ModelOnDisk , base : BaseModelType ) -> dict [SubModelType , SubmodelDefinition ]:
1222
+ def _get_submodels_or_raise (
1223
+ cls , mod : ModelOnDisk , base : MainDiffusers_SupportedBases
1224
+ ) -> dict [SubModelType , SubmodelDefinition ]:
1214
1225
if base is not BaseModelType .StableDiffusion3 :
1215
1226
raise ValueError (f"Attempted to get submodels for non-SD3 model base '{ base } '" )
1216
1227
0 commit comments