Skip to content

Commit 935fafe

Browse files
refactor(mm): make config classes narrow
Simpler identification logic, less complexity when adding a new model, fewer useless attributes that do not relate to the model architecture, etc.
1 parent bab7f62 commit 935fafe

File tree

9 files changed

+605
-402
lines changed

9 files changed

+605
-402
lines changed

invokeai/app/invocations/flux_ip_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
1818
from invokeai.app.services.shared.invocation_context import InvocationContext
1919
from invokeai.backend.model_manager.config import (
20+
IPAdapter_InvokeAI_Config_Base,
2021
IPAdapterCheckpointConfig,
21-
IPAdapterInvokeAIConfig,
2222
)
2323
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
2424

@@ -68,7 +68,7 @@ def validate_begin_end_step_percent(self) -> Self:
6868
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
6969
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
7070
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
71-
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
71+
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
7272

7373
# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
7474
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

invokeai/app/invocations/ip_adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from invokeai.app.services.shared.invocation_context import InvocationContext
1414
from invokeai.backend.model_manager.config import (
1515
AnyModelConfig,
16+
IPAdapter_InvokeAI_Config_Base,
1617
IPAdapterCheckpointConfig,
17-
IPAdapterInvokeAIConfig,
1818
)
1919
from invokeai.backend.model_manager.starter_models import (
2020
StarterModel,
@@ -123,9 +123,9 @@ def validate_begin_end_step_percent(self) -> Self:
123123
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
124124
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
125125
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
126-
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
126+
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
127127

128-
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
128+
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
129129
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
130130
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
131131
else:

invokeai/backend/model_manager/config.py

Lines changed: 545 additions & 358 deletions
Large diffs are not rendered by default.

invokeai/backend/model_manager/load/load_base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
import torch
1313

1414
from invokeai.app.services.config import InvokeAIAppConfig
15-
from invokeai.backend.model_manager.config import (
16-
AnyModelConfig,
17-
)
15+
from invokeai.backend.model_manager.config import AnyModelConfig
1816
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
1917
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
2018
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType

invokeai/backend/model_manager/load/model_loaders/controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from invokeai.backend.model_manager.config import (
99
AnyModelConfig,
10-
ControlNetCheckpointConfig,
10+
ControlNet_Checkpoint_Config_Base,
1111
)
1212
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
1313
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@@ -46,7 +46,7 @@ def _load_model(
4646
config: AnyModelConfig,
4747
submodel_type: Optional[SubModelType] = None,
4848
) -> AnyModel:
49-
if isinstance(config, ControlNetCheckpointConfig):
49+
if isinstance(config, ControlNet_Checkpoint_Config_Base):
5050
return ControlNetModel.from_single_file(
5151
config.path,
5252
torch_dtype=self._torch_dtype,

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,26 @@
3737
from invokeai.backend.model_manager.config import (
3838
AnyModelConfig,
3939
CheckpointConfigBase,
40-
CLIPEmbedDiffusersConfig,
41-
ControlNetCheckpointConfig,
42-
ControlNetDiffusersConfig,
43-
FLUX_Quantized_BnB_NF4_CheckpointConfig,
44-
FLUX_Quantized_GGUF_CheckpointConfig,
45-
FLUX_Unquantized_CheckpointConfig,
46-
FluxReduxConfig,
47-
IPAdapterCheckpointConfig,
48-
T5EncoderBnbQuantizedLlmInt8bConfig,
49-
T5EncoderConfig,
50-
VAECheckpointConfig,
40+
CLIPEmbed_Diffusers_Config_Base,
41+
ControlNet_Checkpoint_Config_Base,
42+
ControlNet_Diffusers_Config_Base,
43+
FLUXRedux_Checkpoint_Config,
44+
IPAdapter_Checkpoint_Config_Base,
45+
Main_FLUX_BnBNF4_Config,
46+
Main_FLUX_Checkpoint_Config,
47+
Main_FLUX_GGUF_Config,
48+
T5Encoder_BnBLLMint8_Config,
49+
T5Encoder_T5Encoder_Config,
50+
VAE_Checkpoint_Config_Base,
5151
)
5252
from invokeai.backend.model_manager.load.load_default import ModelLoader
5353
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
5454
from invokeai.backend.model_manager.taxonomy import (
5555
AnyModel,
5656
BaseModelType,
57+
FluxVariantType,
5758
ModelFormat,
5859
ModelType,
59-
ModelVariantType,
6060
SubModelType,
6161
)
6262
from invokeai.backend.model_manager.util.model_util import (
@@ -86,7 +86,7 @@ def _load_model(
8686
config: AnyModelConfig,
8787
submodel_type: Optional[SubModelType] = None,
8888
) -> AnyModel:
89-
if not isinstance(config, VAECheckpointConfig):
89+
if not isinstance(config, VAE_Checkpoint_Config_Base):
9090
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
9191
model_path = Path(config.path)
9292

@@ -116,7 +116,7 @@ def _load_model(
116116
config: AnyModelConfig,
117117
submodel_type: Optional[SubModelType] = None,
118118
) -> AnyModel:
119-
if not isinstance(config, CLIPEmbedDiffusersConfig):
119+
if not isinstance(config, CLIPEmbed_Diffusers_Config_Base):
120120
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
121121

122122
match submodel_type:
@@ -139,7 +139,7 @@ def _load_model(
139139
config: AnyModelConfig,
140140
submodel_type: Optional[SubModelType] = None,
141141
) -> AnyModel:
142-
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
142+
if not isinstance(config, T5Encoder_BnBLLMint8_Config):
143143
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
144144
if not bnb_available:
145145
raise ImportError(
@@ -186,7 +186,7 @@ def _load_model(
186186
config: AnyModelConfig,
187187
submodel_type: Optional[SubModelType] = None,
188188
) -> AnyModel:
189-
if not isinstance(config, T5EncoderConfig):
189+
if not isinstance(config, T5Encoder_T5Encoder_Config):
190190
raise ValueError("Only T5EncoderConfig models are currently supported here.")
191191

192192
match submodel_type:
@@ -226,7 +226,7 @@ def _load_from_singlefile(
226226
self,
227227
config: AnyModelConfig,
228228
) -> AnyModel:
229-
assert isinstance(config, FLUX_Unquantized_CheckpointConfig)
229+
assert isinstance(config, Main_FLUX_Checkpoint_Config)
230230
model_path = Path(config.path)
231231

232232
with accelerate.init_empty_weights():
@@ -268,7 +268,7 @@ def _load_from_singlefile(
268268
self,
269269
config: AnyModelConfig,
270270
) -> AnyModel:
271-
assert isinstance(config, FLUX_Quantized_GGUF_CheckpointConfig)
271+
assert isinstance(config, Main_FLUX_GGUF_Config)
272272
model_path = Path(config.path)
273273

274274
with accelerate.init_empty_weights():
@@ -314,7 +314,7 @@ def _load_from_singlefile(
314314
self,
315315
config: AnyModelConfig,
316316
) -> AnyModel:
317-
assert isinstance(config, FLUX_Quantized_BnB_NF4_CheckpointConfig)
317+
assert isinstance(config, Main_FLUX_BnBNF4_Config)
318318
if not bnb_available:
319319
raise ImportError(
320320
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
@@ -342,9 +342,9 @@ def _load_model(
342342
config: AnyModelConfig,
343343
submodel_type: Optional[SubModelType] = None,
344344
) -> AnyModel:
345-
if isinstance(config, ControlNetCheckpointConfig):
345+
if isinstance(config, ControlNet_Checkpoint_Config_Base):
346346
model_path = Path(config.path)
347-
elif isinstance(config, ControlNetDiffusersConfig):
347+
elif isinstance(config, ControlNet_Diffusers_Config_Base):
348348
# If this is a diffusers directory, we simply ignore the config file and load from the weight file.
349349
model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
350350
else:
@@ -363,7 +363,7 @@ def _load_model(
363363
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
364364
with accelerate.init_empty_weights():
365365
# HACK(ryand): Is it safe to assume dev here?
366-
model = XLabsControlNetFlux(get_flux_transformers_params(ModelVariantType.FluxDev))
366+
model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev))
367367

368368
model.load_state_dict(sd, assign=True)
369369
return model
@@ -389,7 +389,7 @@ def _load_model(
389389
config: AnyModelConfig,
390390
submodel_type: Optional[SubModelType] = None,
391391
) -> AnyModel:
392-
if not isinstance(config, IPAdapterCheckpointConfig):
392+
if not isinstance(config, IPAdapter_Checkpoint_Config_Base):
393393
raise ValueError(f"Unexpected model config type: {type(config)}.")
394394

395395
sd = load_file(Path(config.path))
@@ -412,7 +412,7 @@ def _load_model(
412412
config: AnyModelConfig,
413413
submodel_type: Optional[SubModelType] = None,
414414
) -> AnyModel:
415-
if not isinstance(config, FluxReduxConfig):
415+
if not isinstance(config, FLUXRedux_Checkpoint_Config):
416416
raise ValueError(f"Unexpected model config type: {type(config)}.")
417417

418418
sd = load_file(Path(config.path))

invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@
1515
AnyModelConfig,
1616
CheckpointConfigBase,
1717
DiffusersConfigBase,
18-
SD_1_2_XL_XLRefiner_CheckpointConfig,
19-
SD_1_2_XL_XLRefiner_DiffusersConfig,
18+
Main_SD1_Checkpoint_Config,
19+
Main_SD1_Diffusers_Config,
20+
Main_SD2_Checkpoint_Config,
21+
Main_SD2_Diffusers_Config,
22+
Main_SDXL_Checkpoint_Config,
23+
Main_SDXL_Diffusers_Config,
24+
Main_SDXLRefiner_Checkpoint_Config,
25+
Main_SDXLRefiner_Diffusers_Config,
2026
)
2127
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
2228
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@@ -108,7 +114,19 @@ def _load_from_singlefile(
108114
ModelVariantType.Normal: StableDiffusionXLPipeline,
109115
},
110116
}
111-
assert isinstance(config, (SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_CheckpointConfig))
117+
assert isinstance(
118+
config,
119+
(
120+
Main_SD1_Diffusers_Config,
121+
Main_SD2_Diffusers_Config,
122+
Main_SDXL_Diffusers_Config,
123+
Main_SDXLRefiner_Diffusers_Config,
124+
Main_SD1_Checkpoint_Config,
125+
Main_SD2_Checkpoint_Config,
126+
Main_SDXL_Checkpoint_Config,
127+
Main_SDXLRefiner_Checkpoint_Config,
128+
),
129+
)
112130
try:
113131
load_class = load_classes[config.base][config.variant]
114132
except KeyError as e:

invokeai/backend/model_manager/load/model_loaders/vae.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
from typing import Optional
55

6-
from diffusers import AutoencoderKL
6+
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
77

8-
from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig
8+
from invokeai.backend.model_manager.config import AnyModelConfig, VAE_Checkpoint_Config_Base
99
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
1010
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
1111
from invokeai.backend.model_manager.taxonomy import (
@@ -27,7 +27,7 @@ def _load_model(
2727
config: AnyModelConfig,
2828
submodel_type: Optional[SubModelType] = None,
2929
) -> AnyModel:
30-
if isinstance(config, VAECheckpointConfig):
30+
if isinstance(config, VAE_Checkpoint_Config_Base):
3131
return AutoencoderKL.from_single_file(
3232
config.path,
3333
torch_dtype=self._torch_dtype,

tests/app/services/model_records/test_model_records_sql.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
ControlAdapterDefaultSettings,
2222
MainDiffusersConfig,
2323
MainModelDefaultSettings,
24-
TextualInversionFileConfig,
24+
TI_File_Config,
2525
VAEDiffusersConfig,
2626
)
2727
from invokeai.backend.model_manager.taxonomy import ModelSourceType
@@ -40,8 +40,8 @@ def store(
4040
return ModelRecordServiceSQL(db, logger)
4141

4242

43-
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
44-
config = TextualInversionFileConfig(
43+
def example_ti_config(key: Optional[str] = None) -> TI_File_Config:
44+
config = TI_File_Config(
4545
source="test/source/",
4646
source_type=ModelSourceType.Path,
4747
path="/tmp/pokemon.bin",
@@ -61,7 +61,7 @@ def test_type(store: ModelRecordServiceBase):
6161
config = example_ti_config("key1")
6262
store.add_model(config)
6363
config1 = store.get_model("key1")
64-
assert isinstance(config1, TextualInversionFileConfig)
64+
assert isinstance(config1, TI_File_Config)
6565

6666

6767
def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):

0 commit comments

Comments
 (0)