@@ -1133,9 +1133,11 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1133
1133
}:
1134
1134
variant = fields .get ("variant" ) or cls ._get_variant_or_raise (mod , base )
1135
1135
prediction_type = fields .get ("prediction_type" ) or cls ._get_scheduler_prediction_type_or_raise (mod , base )
1136
- upcast_attention = fields .get ("upcast_attention" ) or cls ._get_upcast_attention_or_raise (base , prediction_type )
1136
+ upcast_attention = fields .get ("upcast_attention" ) or cls ._get_upcast_attention_or_raise (
1137
+ base , prediction_type
1138
+ )
1137
1139
else :
1138
- variant = None
1140
+ variant = None
1139
1141
prediction_type = None
1140
1142
upcast_attention = False
1141
1143
@@ -1144,7 +1146,16 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
1144
1146
else :
1145
1147
submodels = None
1146
1148
1147
- return cls (** fields , base = base , variant = variant , prediction_type = prediction_type , upcast_attention = upcast_attention , submodels = submodels ,)
1149
+ return cls (
1150
+ ** fields ,
1151
+ base = base ,
1152
+ # TODO(psyche): figure out variant/prediction_type/upcast_attention
1153
+ variant = variant ,
1154
+ prediction_type = prediction_type ,
1155
+ upcast_attention = upcast_attention ,
1156
+ # TODO(psyche): This is only for SD3 models - split up the config classes
1157
+ submodels = submodels ,
1158
+ )
1148
1159
1149
1160
@classmethod
1150
1161
def _get_base_or_raise (cls , mod : ModelOnDisk ) -> MainDiffusers_SupportedBases :
@@ -1282,9 +1293,10 @@ def _get_submodels_or_raise(
1282
1293
1283
1294
return submodels
1284
1295
1285
-
1286
1296
@classmethod
1287
- def _get_upcast_attention_or_raise (cls , base : MainDiffusers_SupportedBases , prediction_type : SchedulerPredictionType ) -> bool :
1297
+ def _get_upcast_attention_or_raise (
1298
+ cls , base : MainDiffusers_SupportedBases , prediction_type : SchedulerPredictionType
1299
+ ) -> bool :
1288
1300
if base not in {
1289
1301
BaseModelType .StableDiffusion1 ,
1290
1302
BaseModelType .StableDiffusion2 ,
@@ -1298,6 +1310,7 @@ def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, pred
1298
1310
1299
1311
return False
1300
1312
1313
+
1301
1314
class IPAdapterConfigBase (ABC , BaseModel ):
1302
1315
type : Literal [ModelType .IPAdapter ] = Field (default = ModelType .IPAdapter )
1303
1316
0 commit comments