Skip to content

Commit 1eee328

Browse files
feat(mm): support bria-3 controlnets
1 parent feecfe5 commit 1eee328

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,34 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
547547
def get_tag(cls) -> Tag:
548548
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}.{BaseModelType.Bria.value}")
549549

550+
class BriaControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfigBase):
551+
"""Model config for Bria/Diffusers ControlNet models."""
552+
553+
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
554+
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
555+
base: Literal[BaseModelType.Bria] = BaseModelType.Bria
556+
557+
@classmethod
558+
def matches(cls, mod: ModelOnDisk) -> bool:
559+
if mod.path.is_file():
560+
return False
561+
562+
config_path = mod.path / "config.json"
563+
if config_path.exists():
564+
with open(config_path) as file:
565+
transformer_conf = json.load(file)
566+
if transformer_conf["_class_name"] == "BriaTransformer2DModel":
567+
return True
568+
569+
return False
570+
571+
@classmethod
572+
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
573+
return {}
574+
575+
@classmethod
576+
def get_tag(cls) -> Tag:
577+
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}.{BaseModelType.Bria.value}")
550578

551579

552580
class IPAdapterConfigBase(ABC, BaseModel):
@@ -732,6 +760,7 @@ def get_model_discriminator_value(v: Any) -> str:
732760
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
733761
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
734762
Annotated[BriaDiffusersConfig, BriaDiffusersConfig.get_tag()],
763+
Annotated[BriaControlNetDiffusersConfig, BriaControlNetDiffusersConfig.get_tag()],
735764
Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
736765
Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
737766
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ class ModelProbe(object):
125125
}
126126

127127
CLASS2TYPE = {
128-
"BriaPipeline": ModelType.Main,
129-
"BriaTransformer2DModel": ModelType.ControlNet,
130128
"FluxPipeline": ModelType.Main,
131129
"StableDiffusionPipeline": ModelType.Main,
132130
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -863,8 +861,6 @@ def get_base_type(self) -> BaseModelType:
863861
return BaseModelType.StableDiffusion3
864862
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
865863
return BaseModelType.CogView4
866-
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
867-
return BaseModelType.Bria
868864
else:
869865
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
870866

@@ -1014,9 +1010,6 @@ def get_base_type(self) -> BaseModelType:
10141010
if config.get("_class_name", None) == "FluxControlNetModel":
10151011
return BaseModelType.Flux
10161012

1017-
if config.get("_class_name", None) == "BriaTransformer2DModel":
1018-
return BaseModelType.Bria
1019-
10201013
# no obvious way to distinguish between sd2-base and sd2-768
10211014
dimension = config["cross_attention_dim"]
10221015
if dimension == 768:

0 commit comments

Comments
 (0)