@@ -1126,12 +1126,29 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1126
1126
)
1127
1127
1128
1128
base = fields .get ("base" ) or cls ._get_base_or_raise (mod )
1129
+ if base in {
1130
+ BaseModelType .StableDiffusion1 ,
1131
+ BaseModelType .StableDiffusion2 ,
1132
+ BaseModelType .StableDiffusionXL ,
1133
+ }:
1134
+ variant = fields .get ("variant" ) or cls ._get_variant_or_raise (mod , base )
1135
+ prediction_type = fields .get ("prediction_type" ) or cls ._get_scheduler_prediction_type_or_raise (mod , base )
1136
+ upcast_attention = fields .get ("upcast_attention" ) or cls ._get_upcast_attention_or_raise (base , prediction_type )
1137
+ else :
1138
+ variant = None
1139
+ prediction_type = None
1140
+ upcast_attention = False
1129
1141
1130
- return cls (** fields , base = base )
1142
+ if base is BaseModelType .StableDiffusion3 :
1143
+ submodels = fields .get ("submodels" ) or cls ._get_submodels_or_raise (mod , base )
1144
+ else :
1145
+ submodels = None
1146
+
1147
+ return cls (** fields , base = base , variant = variant , prediction_type = prediction_type , upcast_attention = upcast_attention , submodels = submodels ,)
1131
1148
1132
1149
@classmethod
1133
1150
def _get_base_or_raise (cls , mod : ModelOnDisk ) -> MainDiffusers_SupportedBases :
1134
- # Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
1151
+ # Handle pipelines with a UNet (i.e SD 1.x, SD2.x , SDXL).
1135
1152
unet_config_path = mod .path / "unet" / "config.json"
1136
1153
if unet_config_path .exists ():
1137
1154
with open (unet_config_path ) as file :
@@ -1172,7 +1189,7 @@ def _get_scheduler_prediction_type_or_raise(
1172
1189
BaseModelType .StableDiffusion2 ,
1173
1190
BaseModelType .StableDiffusionXL ,
1174
1191
}:
1175
- raise ValueError (f"Attempted to get scheduler prediction type for non-UNet model base '{ base } '" )
1192
+ raise ValueError (f"Attempted to get scheduler prediction_type for non-UNet model base '{ base } '" )
1176
1193
1177
1194
scheduler_conf = _get_config_or_raise (cls , mod .path / "scheduler" / "scheduler_config.json" )
1178
1195
@@ -1185,7 +1202,7 @@ def _get_scheduler_prediction_type_or_raise(
1185
1202
case "epsilon" :
1186
1203
return SchedulerPredictionType .Epsilon
1187
1204
case _:
1188
- raise NotAMatch (cls , f"unrecognized scheduler prediction type { prediction_type } " )
1205
+ raise NotAMatch (cls , f"unrecognized scheduler prediction_type { prediction_type } " )
1189
1206
1190
1207
@classmethod
1191
1208
def _get_variant_or_raise (cls , mod : ModelOnDisk , base : MainDiffusers_SupportedBases ) -> ModelVariantType :
@@ -1266,6 +1283,21 @@ def _get_submodels_or_raise(
1266
1283
return submodels
1267
1284
1268
1285
1286
+ @classmethod
1287
+ def _get_upcast_attention_or_raise (cls , base : MainDiffusers_SupportedBases , prediction_type : SchedulerPredictionType ) -> bool :
1288
+ if base not in {
1289
+ BaseModelType .StableDiffusion1 ,
1290
+ BaseModelType .StableDiffusion2 ,
1291
+ BaseModelType .StableDiffusionXL ,
1292
+ }:
1293
+ raise ValueError (f"Attempted to get upcast_attention flag for non-UNet model base '{ base } '" )
1294
+
1295
+ if base is BaseModelType .StableDiffusion2 and prediction_type is SchedulerPredictionType .VPrediction :
1296
+ # SD2 v-prediction models need upcast_attention to be True
1297
+ return True
1298
+
1299
+ return False
1300
+
1269
1301
class IPAdapterConfigBase (ABC , BaseModel ):
1270
1302
type : Literal [ModelType .IPAdapter ] = Field (default = ModelType .IPAdapter )
1271
1303
0 commit comments