@@ -406,16 +406,94 @@ def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
406
406
class T5EncoderConfigBase (ABC , BaseModel ):
407
407
"""Base class for diffusers-style models."""
408
408
409
+ base : Literal [BaseModelType .Any ] = BaseModelType .Any
409
410
type : Literal [ModelType .T5Encoder ] = ModelType .T5Encoder
410
411
412
+ @classmethod
413
+ def get_config (cls , mod : ModelOnDisk ) -> dict [str , Any ]:
414
+ path = mod .path / "text_encoder_2" / "config.json"
415
+ with open (path , "r" ) as file :
416
+ return json .load (file )
417
+
418
+ @classmethod
419
+ def parse (cls , mod : ModelOnDisk ) -> dict [str , Any ]:
420
+ return {}
421
+
411
422
412
- class T5EncoderConfig (T5EncoderConfigBase , LegacyProbeMixin , ModelConfigBase ):
423
+ class T5EncoderConfig (T5EncoderConfigBase , ModelConfigBase ):
413
424
format : Literal [ModelFormat .T5Encoder ] = ModelFormat .T5Encoder
414
425
426
+ @classmethod
427
+ def matches (cls , mod : ModelOnDisk , ** overrides ) -> MatchCertainty :
428
+ is_t5_type_override = overrides .get ("type" ) is ModelType .T5Encoder
429
+ is_t5_format_override = overrides .get ("format" ) is ModelFormat .T5Encoder
430
+
431
+ if is_t5_type_override and is_t5_format_override :
432
+ return MatchCertainty .OVERRIDE
433
+
434
+ if mod .path .is_file ():
435
+ return MatchCertainty .NEVER
436
+
437
+ model_dir = mod .path / "text_encoder_2"
438
+
439
+ if not model_dir .exists ():
440
+ return MatchCertainty .NEVER
441
+
442
+ try :
443
+ config = cls .get_config (mod )
444
+
445
+ is_t5_encoder_model = get_class_name_from_config (config ) == "T5EncoderModel"
446
+ is_t5_format = (model_dir / "model.safetensors.index.json" ).exists ()
415
447
416
- class T5EncoderBnbQuantizedLlmInt8bConfig (T5EncoderConfigBase , LegacyProbeMixin , ModelConfigBase ):
448
+ if is_t5_encoder_model and is_t5_format :
449
+ return MatchCertainty .EXACT
450
+ except Exception :
451
+ pass
452
+
453
+ return MatchCertainty .NEVER
454
+
455
+
456
+ class T5EncoderBnbQuantizedLlmInt8bConfig (T5EncoderConfigBase , ModelConfigBase ):
417
457
format : Literal [ModelFormat .BnbQuantizedLlmInt8b ] = ModelFormat .BnbQuantizedLlmInt8b
418
458
459
+ @classmethod
460
+ def matches (cls , mod : ModelOnDisk , ** overrides ) -> MatchCertainty :
461
+ is_t5_type_override = overrides .get ("type" ) is ModelType .T5Encoder
462
+ is_bnb_format_override = overrides .get ("format" ) is ModelFormat .BnbQuantizedLlmInt8b
463
+
464
+ if is_t5_type_override and is_bnb_format_override :
465
+ return MatchCertainty .OVERRIDE
466
+
467
+ if mod .path .is_file ():
468
+ return MatchCertainty .NEVER
469
+
470
+ model_dir = mod .path / "text_encoder_2"
471
+
472
+ if not model_dir .exists ():
473
+ return MatchCertainty .NEVER
474
+
475
+ try :
476
+ config = cls .get_config (mod )
477
+
478
+ is_t5_encoder_model = get_class_name_from_config (config ) == "T5EncoderModel"
479
+
480
+ # Heuristic: look for the quantization in the name
481
+ files = model_dir .glob ("*.safetensors" )
482
+ filename_looks_like_bnb = any (x for x in files if "llm_int8" in x .as_posix ())
483
+
484
+ if is_t5_encoder_model and filename_looks_like_bnb :
485
+ return MatchCertainty .EXACT
486
+
487
+ # Heuristic: Look for the presence of "SCB" in state dict keys (typically a suffix)
488
+ has_scb_key = mod .has_keys_ending_with ("SCB" )
489
+
490
+ if is_t5_encoder_model and has_scb_key :
491
+ return MatchCertainty .EXACT
492
+ except Exception :
493
+ pass
494
+
495
+ return MatchCertainty .NEVER
496
+
419
497
420
498
class LoRAOmiConfig (LoRAConfigBase , ModelConfigBase ):
421
499
format : Literal [ModelFormat .OMI ] = ModelFormat .OMI
0 commit comments