
Commit 89bda5b

Ryan/sd3 diffusers (#7222)
## Summary

Nodes to support SD3.5 txt2img generations

* adds SD3.5 to starter models
* adds default workflow for SD3.5 txt2img

## Related Issues / Discussions

## QA Instructions

## Merge Plan

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2 parents 0f11fda + 22bff1f commit 89bda5b

File tree

41 files changed: +2,523 −171 lines


invokeai/app/invocations/fields.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     # region Model Field Types
     MainModel = "MainModelField"
     FluxMainModel = "FluxMainModelField"
+    SD3MainModel = "SD3MainModelField"
     SDXLMainModel = "SDXLMainModelField"
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
@@ -52,6 +53,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
     T2IAdapterModel = "T2IAdapterModelField"
     T5EncoderModel = "T5EncoderModelField"
     CLIPEmbedModel = "CLIPEmbedModelField"
+    CLIPLEmbedModel = "CLIPLEmbedModelField"
+    CLIPGEmbedModel = "CLIPGEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion
 
@@ -131,15 +134,18 @@ class FieldDescriptions:
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     t5_encoder = "T5 tokenizer and text encoder"
     clip_embed_model = "CLIP Embed loader"
+    clip_g_model = "CLIP-G Embed loader"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
+    mmditx = "MMDiTX"
     vae = "VAE"
     cond = "Conditioning tensor"
     controlnet_model = "ControlNet model to load"
     vae_model = "VAE model to load"
     lora_model = "LoRA model to load"
     main_model = "Main model (UNet, VAE, CLIP) to load"
     flux_model = "Flux model (Transformer) to load"
+    sd3_model = "SD3 model (MMDiTX) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -246,6 +252,12 @@ class FluxConditioningField(BaseModel):
     conditioning_name: str = Field(description="The name of conditioning tensor")
 
 
+class SD3ConditioningField(BaseModel):
+    """A conditioning tensor primitive value"""
+
+    conditioning_name: str = Field(description="The name of conditioning tensor")
+
+
 class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""
 
```
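For orientation, the additions above are plain data: `UIType` is a string-valued enum consumed by the frontend's node schemas, and `FieldDescriptions` holds the hover text for node inputs. A minimal sketch of how the new entries surface, grounded only in this diff (it does not show the SD3 loader node itself):

```python
from invokeai.app.invocations.fields import FieldDescriptions, UIType

# UIType subclasses str, so the new members compare equal to (and serialize as)
# their string values in the node schemas read by the workflow editor.
assert UIType.SD3MainModel == "SD3MainModelField"
assert UIType.CLIPLEmbedModel == "CLIPLEmbedModelField"
assert UIType.CLIPGEmbedModel == "CLIPGEmbedModelField"

# The matching description string that an SD3 model loader input would display.
print(FieldDescriptions.sd3_model)  # SD3 model (MMDiTX) to load
```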

New file · Lines changed: 89 additions & 0 deletions
```diff
@@ -0,0 +1,89 @@
+from typing import Literal
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.util import max_seq_lengths
+from invokeai.backend.model_manager.config import (
+    CheckpointConfigBase,
+    SubModelType,
+)
+
+
+@invocation_output("flux_model_loader_output")
+class FluxModelLoaderOutput(BaseInvocationOutput):
+    """Flux base model loader output"""
+
+    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
+    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
+    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+    max_seq_len: Literal[256, 512] = OutputField(
+        description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
+        title="Max Seq Length",
+    )
+
+
+@invocation(
+    "flux_model_loader",
+    title="Flux Main Model",
+    tags=["model", "flux"],
+    category="model",
+    version="1.0.4",
+    classification=Classification.Prototype,
+)
+class FluxModelLoaderInvocation(BaseInvocation):
+    """Loads a flux base model, outputting its submodels."""
+
+    model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.flux_model,
+        ui_type=UIType.FluxMainModel,
+        input=Input.Direct,
+    )
+
+    t5_encoder_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
+    )
+
+    clip_embed_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.clip_embed_model,
+        ui_type=UIType.CLIPEmbedModel,
+        input=Input.Direct,
+        title="CLIP Embed",
+    )
+
+    vae_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
+    )
+
+    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
+        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
+            if not context.models.exists(key):
+                raise ValueError(f"Unknown model: {key}")
+
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+
+        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
+
+        transformer_config = context.models.get_config(transformer)
+        assert isinstance(transformer_config, CheckpointConfigBase)
+
+        return FluxModelLoaderOutput(
+            transformer=TransformerField(transformer=transformer, loras=[]),
+            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
+            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
+            vae=VAEField(vae=vae),
+            max_seq_len=max_seq_lengths[transformer_config.config_path],
+        )
```
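This file is new in the commit, but the code is not: it is the Flux model loader moved out of `invokeai/app/invocations/model.py` (see the diff below). The loader never touches weights directly; it copies the selected model identifier several times, tagging each copy with a different `SubModelType`, and downstream nodes load those submodels lazily. A self-contained sketch of that `model_copy(update=...)` pattern, using a simplified `ModelRef` stand-in rather than the real `ModelIdentifierField` and a made-up model key:

```python
from typing import Optional

from pydantic import BaseModel


class ModelRef(BaseModel):
    """Simplified stand-in for ModelIdentifierField (illustration only)."""

    key: str
    submodel_type: Optional[str] = None


# One selected model fans out into per-submodel references, much as the
# loader's invoke() does with SubModelType.Transformer, .Tokenizer, .VAE, etc.
base = ModelRef(key="flux-schnell")  # hypothetical model key
transformer = base.model_copy(update={"submodel_type": "transformer"})
vae = base.model_copy(update={"submodel_type": "vae"})

print(transformer)  # key='flux-schnell' submodel_type='transformer'
print(vae)          # key='flux-schnell' submodel_type='vae'
```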

invokeai/app/invocations/model.py

Lines changed: 1 addition & 75 deletions
```diff
@@ -1,5 +1,5 @@
 import copy
-from typing import List, Literal, Optional
+from typing import List, Optional
 
 from pydantic import BaseModel, Field
 
@@ -13,11 +13,9 @@
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.flux.util import max_seq_lengths
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
-    CheckpointConfigBase,
     ModelType,
     SubModelType,
 )
@@ -139,78 +137,6 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
         return ModelIdentifierOutput(model=self.model)
 
 
-@invocation_output("flux_model_loader_output")
-class FluxModelLoaderOutput(BaseInvocationOutput):
-    """Flux base model loader output"""
-
-    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
-    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
-    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
-    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
-    max_seq_len: Literal[256, 512] = OutputField(
-        description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
-        title="Max Seq Length",
-    )
-
-
-@invocation(
-    "flux_model_loader",
-    title="Flux Main Model",
-    tags=["model", "flux"],
-    category="model",
-    version="1.0.4",
-    classification=Classification.Prototype,
-)
-class FluxModelLoaderInvocation(BaseInvocation):
-    """Loads a flux base model, outputting its submodels."""
-
-    model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.flux_model,
-        ui_type=UIType.FluxMainModel,
-        input=Input.Direct,
-    )
-
-    t5_encoder_model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
-    )
-
-    clip_embed_model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.clip_embed_model,
-        ui_type=UIType.CLIPEmbedModel,
-        input=Input.Direct,
-        title="CLIP Embed",
-    )
-
-    vae_model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
-    )
-
-    def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
-        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
-            if not context.models.exists(key):
-                raise ValueError(f"Unknown model: {key}")
-
-        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
-        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
-
-        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
-        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
-
-        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
-        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
-
-        transformer_config = context.models.get_config(transformer)
-        assert isinstance(transformer_config, CheckpointConfigBase)
-
-        return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer, loras=[]),
-            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
-            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
-            vae=VAEField(vae=vae),
-            max_seq_len=max_seq_lengths[transformer_config.config_path],
-        )
-
-
 @invocation(
     "main_model_loader",
     title="Main Model",
```

invokeai/app/invocations/primitives.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
     InputField,
     LatentsField,
     OutputField,
+    SD3ConditioningField,
     TensorField,
     UIComponent,
 )
@@ -426,6 +427,17 @@ def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
         return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
 
 
+@invocation_output("sd3_conditioning_output")
+class SD3ConditioningOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single SD3 conditioning tensor"""
+
+    conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
+
+    @classmethod
+    def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
+        return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
+
+
 @invocation_output("conditioning_output")
 class ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single conditioning tensor"""
```
