Skip to content

Commit 69efdc3

Browse files
docs(mm): add todos
1 parent 3a44fde commit 69efdc3

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,9 +1133,11 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
11331133
}:
11341134
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
11351135
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
1136-
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(base, prediction_type)
1136+
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(
1137+
base, prediction_type
1138+
)
11371139
else:
1138-
variant= None
1140+
variant = None
11391141
prediction_type = None
11401142
upcast_attention = False
11411143

@@ -1144,7 +1146,16 @@ def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
11441146
else:
11451147
submodels = None
11461148

1147-
return cls(**fields, base=base, variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, submodels=submodels,)
1149+
return cls(
1150+
**fields,
1151+
base=base,
1152+
# TODO(psyche): figure out variant/prediction_type/upcast_attention
1153+
variant=variant,
1154+
prediction_type=prediction_type,
1155+
upcast_attention=upcast_attention,
1156+
# TODO(psyche): This is only for SD3 models - split up the config classes
1157+
submodels=submodels,
1158+
)
11481159

11491160
@classmethod
11501161
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
@@ -1282,9 +1293,10 @@ def _get_submodels_or_raise(
12821293

12831294
return submodels
12841295

1285-
12861296
@classmethod
1287-
def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType) -> bool:
1297+
def _get_upcast_attention_or_raise(
1298+
cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType
1299+
) -> bool:
12881300
if base not in {
12891301
BaseModelType.StableDiffusion1,
12901302
BaseModelType.StableDiffusion2,
@@ -1298,6 +1310,7 @@ def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, pred
12981310

12991311
return False
13001312

1313+
13011314
class IPAdapterConfigBase(ABC, BaseModel):
13021315
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
13031316

0 commit comments

Comments
 (0)