Skip to content

Commit 8217fd9

Browse files
feat(mm): port t5 to new API
1 parent 5996e31 commit 8217fd9

File tree

3 files changed

+89
-33
lines changed

3 files changed

+89
-33
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,16 +406,94 @@ def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
406406
class T5EncoderConfigBase(ABC, BaseModel):
407407
"""Base class for diffusers-style models."""
408408

409+
base: Literal[BaseModelType.Any] = BaseModelType.Any
409410
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
410411

412+
@classmethod
413+
def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
414+
path = mod.path / "text_encoder_2" / "config.json"
415+
with open(path, "r") as file:
416+
return json.load(file)
417+
418+
@classmethod
419+
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
420+
return {}
421+
411422

412-
class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
423+
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
413424
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
414425

426+
@classmethod
427+
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
428+
is_t5_type_override = overrides.get("type") is ModelType.T5Encoder
429+
is_t5_format_override = overrides.get("format") is ModelFormat.T5Encoder
430+
431+
if is_t5_type_override and is_t5_format_override:
432+
return MatchCertainty.OVERRIDE
433+
434+
if mod.path.is_file():
435+
return MatchCertainty.NEVER
436+
437+
model_dir = mod.path / "text_encoder_2"
438+
439+
if not model_dir.exists():
440+
return MatchCertainty.NEVER
441+
442+
try:
443+
config = cls.get_config(mod)
444+
445+
is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel"
446+
is_t5_format = (model_dir / "model.safetensors.index.json").exists()
415447

416-
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
448+
if is_t5_encoder_model and is_t5_format:
449+
return MatchCertainty.EXACT
450+
except Exception:
451+
pass
452+
453+
return MatchCertainty.NEVER
454+
455+
456+
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
417457
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
418458

459+
@classmethod
460+
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
461+
is_t5_type_override = overrides.get("type") is ModelType.T5Encoder
462+
is_bnb_format_override = overrides.get("format") is ModelFormat.BnbQuantizedLlmInt8b
463+
464+
if is_t5_type_override and is_bnb_format_override:
465+
return MatchCertainty.OVERRIDE
466+
467+
if mod.path.is_file():
468+
return MatchCertainty.NEVER
469+
470+
model_dir = mod.path / "text_encoder_2"
471+
472+
if not model_dir.exists():
473+
return MatchCertainty.NEVER
474+
475+
try:
476+
config = cls.get_config(mod)
477+
478+
is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel"
479+
480+
# Heuristic: look for the quantization in the name
481+
files = model_dir.glob("*.safetensors")
482+
filename_looks_like_bnb = any(x for x in files if "llm_int8" in x.as_posix())
483+
484+
if is_t5_encoder_model and filename_looks_like_bnb:
485+
return MatchCertainty.EXACT
486+
487+
# Heuristic: Look for the presence of "SCB" in state dict keys (typically a suffix)
488+
has_scb_key = mod.has_keys_ending_with("SCB")
489+
490+
if is_t5_encoder_model and has_scb_key:
491+
return MatchCertainty.EXACT
492+
except Exception:
493+
pass
494+
495+
return MatchCertainty.NEVER
496+
419497

420498
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
421499
format: Literal[ModelFormat.OMI] = ModelFormat.OMI

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -879,30 +879,6 @@ def get_variant_type(self) -> ModelVariantType:
879879
return ModelVariantType.Normal
880880

881881

882-
class T5EncoderFolderProbe(FolderProbeBase):
883-
def get_base_type(self) -> BaseModelType:
884-
return BaseModelType.Any
885-
886-
def get_format(self) -> ModelFormat:
887-
path = self.model_path / "text_encoder_2"
888-
if (path / "model.safetensors.index.json").exists():
889-
return ModelFormat.T5Encoder
890-
files = list(path.glob("*.safetensors"))
891-
if len(files) == 0:
892-
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found")
893-
894-
# shortcut: look for the quantization in the name
895-
if any(x for x in files if "llm_int8" in x.as_posix()):
896-
return ModelFormat.BnbQuantizedLlmInt8b
897-
898-
# more reliable path: probe contents for a 'SCB' key
899-
ckpt = read_checkpoint_meta(files[0], scan=True)
900-
if any("SCB" in x for x in ckpt.keys()):
901-
return ModelFormat.BnbQuantizedLlmInt8b
902-
903-
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format")
904-
905-
906882
class ONNXFolderProbe(PipelineFolderProbe):
907883
def get_base_type(self) -> BaseModelType:
908884
# Due to the way the installer is set up, the configuration file for safetensors
@@ -1036,7 +1012,6 @@ def get_base_type(self) -> BaseModelType:
10361012
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
10371013
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
10381014
ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
1039-
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
10401015
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
10411016
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
10421017
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)

invokeai/backend/model_manager/model_on_disk.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,18 +129,21 @@ def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
129129
)
130130
return path
131131

132-
def has_keys_exact(self, keys: set[str], path: Optional[Path] = None) -> bool:
132+
def has_keys_exact(self, keys: str | set[str], path: Optional[Path] = None) -> bool:
133+
_keys = {keys} if isinstance(keys, str) else keys
133134
state_dict = self.load_state_dict(path)
134-
return keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
135+
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
135136

136-
def has_keys_starting_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
137+
def has_keys_starting_with(self, prefixes: str | set[str], path: Optional[Path] = None) -> bool:
138+
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
137139
state_dict = self.load_state_dict(path)
138140
return any(
139-
any(key.startswith(prefix) for prefix in prefixes) for key in state_dict.keys() if isinstance(key, str)
141+
any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str)
140142
)
141143

142-
def has_keys_ending_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
144+
def has_keys_ending_with(self, suffixes: str | set[str], path: Optional[Path] = None) -> bool:
145+
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
143146
state_dict = self.load_state_dict(path)
144147
return any(
145-
any(key.endswith(suffix) for suffix in prefixes) for key in state_dict.keys() if isinstance(key, str)
148+
any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)
146149
)

0 commit comments

Comments
 (0)