Skip to content

Commit 1d3f6c4

Browse files
feat(mm): make config_path optional
1 parent 8217fd9 commit 1d3f6c4

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

invokeai/app/invocations/flux_denoise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
unpack,
4949
)
5050
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
51-
from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
51+
from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelFormat
5252
from invokeai.backend.patches.layer_patcher import LayerPatcher
5353
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
5454
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -232,7 +232,7 @@ def _run_diffusion(
232232
)
233233

234234
transformer_config = context.models.get_config(self.transformer.transformer)
235-
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
235+
is_schnell = transformer_config.variant is FluxVariantType.Schnell
236236

237237
# Calculate the timestep schedule.
238238
timesteps = get_schedule(
@@ -277,7 +277,7 @@ def _run_diffusion(
277277

278278
# Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
279279
img_cond: torch.Tensor | None = None
280-
is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore
280+
is_flux_fill = transformer_config.variant is FluxVariantType.DevFill
281281
if is_flux_fill:
282282
img_cond = self._prep_flux_fill_img_cond(
283283
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype

invokeai/backend/model_manager/config.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def __init_subclass__(cls, **kwargs):
198198
super().__init_subclass__(**kwargs)
199199
if issubclass(cls, LegacyProbeMixin):
200200
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
201-
# Cannot use `elif isinstance(cls, UnknownModelConfig)` because UnknownModelConfig is not defined yet
202201
else:
203202
ModelConfigBase.USING_CLASSIFY_API.add(cls)
204203

@@ -346,11 +345,16 @@ class CheckpointConfigBase(ABC, BaseModel):
346345
"""Base class for checkpoint-style models."""
347346

348347
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
349-
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
348+
description="Format of the provided checkpoint model",
349+
default=ModelFormat.Checkpoint,
350350
)
351-
config_path: str = Field(description="path to the checkpoint model config file")
352-
converted_at: Optional[float] = Field(
353-
description="When this model was last converted to diffusers", default_factory=time.time
351+
config_path: str | None = Field(
352+
description="path to the checkpoint model config file",
353+
default=None,
354+
)
355+
converted_at: float | None = Field(
356+
description="When this model was last converted to diffusers",
357+
default_factory=time.time,
354358
)
355359

356360

invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export const MainModelPicker = memo(() => {
2727
() =>
2828
selectedModelConfig &&
2929
isCheckpointMainModelConfig(selectedModelConfig) &&
30-
selectedModelConfig.config_path === 'flux-dev',
30+
selectedModelConfig.variant === 'flux_dev',
3131
[selectedModelConfig]
3232
);
3333

invokeai/frontend/web/src/features/ui/layouts/InitialStateMainModelPicker.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ export const InitialStateMainModelPicker = memo(() => {
2626
() =>
2727
selectedModelConfig &&
2828
isCheckpointMainModelConfig(selectedModelConfig) &&
29-
selectedModelConfig.config_path === 'flux-dev',
29+
selectedModelConfig.variant === 'flux_dev',
3030
[selectedModelConfig]
3131
);
3232

0 commit comments

Comments
 (0)