1
- from typing import List , Tuple
1
+ from typing import Callable , List , Tuple
2
2
3
3
import torch
4
4
from diffusers .models .autoencoders .autoencoder_kl import AutoencoderKL
5
5
from diffusers .schedulers .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
6
6
7
7
from invokeai .app .invocations .bria_controlnet import BriaControlNetField
8
- from invokeai .app .invocations .fields import Input , InputField , LatentsField , OutputField
8
+ from invokeai .app .invocations .bria_latent_noise import BriaLatentNoiseOutput
9
+ from invokeai .app .invocations .fields import FluxConditioningField , Input , InputField , LatentsField , OutputField
9
10
from invokeai .app .invocations .model import SubModelType , T5EncoderField , TransformerField , VAEField
10
11
from invokeai .app .invocations .primitives import BaseInvocationOutput , FieldDescriptions
11
12
from invokeai .app .services .shared .invocation_context import InvocationContext
12
13
from invokeai .backend .bria .controlnet_bria import BriaControlModes , BriaMultiControlNetModel
13
14
from invokeai .backend .bria .controlnet_utils import prepare_control_images
14
15
from invokeai .backend .bria .pipeline_bria_controlnet import BriaControlNetPipeline
15
16
from invokeai .backend .bria .transformer_bria import BriaTransformer2DModel
17
+ from invokeai .backend .model_manager .taxonomy import BaseModelType
18
+ from invokeai .backend .stable_diffusion .extensions .preview import PipelineIntermediateState
16
19
from invokeai .invocation_api import BaseInvocation , Classification , invocation , invocation_output
17
20
18
21
@@ -30,6 +33,11 @@ class BriaDenoiseInvocationOutput(BaseInvocationOutput):
30
33
classification = Classification .Prototype ,
31
34
)
32
35
class BriaDenoiseInvocation (BaseInvocation ):
36
+
37
+ """
38
+ Denoise Bria latents using a Bria Pipeline.
39
+ """
40
+
33
41
num_steps : int = InputField (
34
42
default = 30 , title = "Number of Steps" , description = "The number of steps to use for the denoiser"
35
43
)
@@ -52,31 +60,31 @@ class BriaDenoiseInvocation(BaseInvocation):
52
60
input = Input .Connection ,
53
61
title = "VAE" ,
54
62
)
55
- latents : LatentsField = InputField (
56
- description = "Latents to denoise" ,
57
- input = Input . Connection ,
58
- title = "Latents " ,
63
+ height : int = InputField (
64
+ default = 1024 ,
65
+ title = "Height" ,
66
+ description = "The height of the output image " ,
59
67
)
60
- latent_image_ids : LatentsField = InputField (
61
- description = "Latent Image IDs to denoise" ,
68
+ width : int = InputField (
69
+ default = 1024 ,
70
+ title = "Width" ,
71
+ description = "The width of the output image" ,
72
+ )
73
+ latent_noise : BriaLatentNoiseOutput = InputField (
74
+ description = "Latent noise to denoise" ,
62
75
input = Input .Connection ,
63
- title = "Latent Image IDs " ,
76
+ title = "Latent Noise " ,
64
77
)
65
- pos_embeds : LatentsField = InputField (
78
+ pos_embeds : FluxConditioningField = InputField (
66
79
description = "Positive Prompt Embeds" ,
67
80
input = Input .Connection ,
68
81
title = "Positive Prompt Embeds" ,
69
82
)
70
- neg_embeds : LatentsField = InputField (
83
+ neg_embeds : FluxConditioningField = InputField (
71
84
description = "Negative Prompt Embeds" ,
72
85
input = Input .Connection ,
73
86
title = "Negative Prompt Embeds" ,
74
87
)
75
- text_ids : LatentsField = InputField (
76
- description = "Text IDs" ,
77
- input = Input .Connection ,
78
- title = "Text IDs" ,
79
- )
80
88
control : BriaControlNetField | list [BriaControlNetField ] | None = InputField (
81
89
description = "ControlNet" ,
82
90
input = Input .Connection ,
@@ -86,11 +94,10 @@ class BriaDenoiseInvocation(BaseInvocation):
86
94
87
95
@torch .no_grad ()
88
96
def invoke (self , context : InvocationContext ) -> BriaDenoiseInvocationOutput :
89
- latents = context .tensors .load (self .latents .latents_name )
90
- pos_embeds = context .tensors .load (self .pos_embeds .latents_name )
91
- neg_embeds = context .tensors .load (self .neg_embeds .latents_name )
92
- text_ids = context .tensors .load (self .text_ids .latents_name )
93
- latent_image_ids = context .tensors .load (self .latent_image_ids .latents_name )
97
+ latents = context .tensors .load (self .latent_noise .latents .latents_name )
98
+ pos_embeds = context .tensors .load (self .pos_embeds .conditioning_name )
99
+ neg_embeds = context .tensors .load (self .neg_embeds .conditioning_name )
100
+ latent_image_ids = context .tensors .load (self .latent_noise .latent_image_ids .latents_name )
94
101
scheduler_identifier = self .transformer .transformer .model_copy (update = {"submodel_type" : SubModelType .Scheduler })
95
102
96
103
device = None
@@ -114,11 +121,12 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
114
121
control_model , control_images , control_modes , control_scales = self ._prepare_multi_control (
115
122
context = context ,
116
123
vae = vae ,
117
- width = 1024 ,
118
- height = 1024 ,
124
+ width = self . width ,
125
+ height = self . height ,
119
126
device = vae .device ,
120
127
)
121
128
129
+
122
130
pipeline = BriaControlNetPipeline (
123
131
transformer = transformer ,
124
132
scheduler = scheduler ,
@@ -129,31 +137,32 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
129
137
)
130
138
pipeline .to (device = transformer .device , dtype = transformer .dtype )
131
139
132
- latents = pipeline (
140
+ output_latents = pipeline (
133
141
control_image = control_images ,
134
142
control_mode = control_modes ,
135
- width = 1024 ,
136
- height = 1024 ,
143
+ width = self . width ,
144
+ height = self . height ,
137
145
controlnet_conditioning_scale = control_scales ,
138
146
num_inference_steps = self .num_steps ,
139
147
max_sequence_length = 128 ,
140
148
guidance_scale = self .guidance_scale ,
141
149
latents = latents ,
142
150
latent_image_ids = latent_image_ids ,
143
- text_ids = text_ids ,
144
151
prompt_embeds = pos_embeds ,
145
152
negative_prompt_embeds = neg_embeds ,
146
153
output_type = "latent" ,
154
+ step_callback = _build_step_callback (context ),
147
155
)[0 ]
148
156
149
- assert isinstance (latents , torch .Tensor )
150
- saved_input_latents_tensor = context .tensors .save (latents )
151
- latents_output = LatentsField (latents_name = saved_input_latents_tensor )
152
- return BriaDenoiseInvocationOutput (latents = latents_output )
157
+
158
+
159
+ assert isinstance (output_latents , torch .Tensor )
160
+ saved_input_latents_tensor = context .tensors .save (output_latents )
161
+ return BriaDenoiseInvocationOutput (latents = LatentsField (latents_name = saved_input_latents_tensor ))
153
162
154
163
def _prepare_multi_control (
155
164
self , context : InvocationContext , vae : AutoencoderKL , width : int , height : int , device : torch .device
156
- ) -> Tuple [BriaMultiControlNetModel , List [torch .Tensor ], List [torch . Tensor ], List [float ]]:
165
+ ) -> Tuple [BriaMultiControlNetModel , List [torch .Tensor ], List [int ], List [float ]]:
157
166
control = self .control if isinstance (self .control , list ) else [self .control ]
158
167
control_images , control_models , control_modes , control_scales = [], [], [], []
159
168
for controlnet in control :
@@ -178,3 +187,11 @@ def _prepare_multi_control(
178
187
device = device ,
179
188
)
180
189
return control_model , tensored_control_images , tensored_control_modes , control_scales
190
+
191
+
192
+ def _build_step_callback (context : InvocationContext ) -> Callable [[PipelineIntermediateState ], None ]:
193
+ def step_callback (state : PipelineIntermediateState ) -> None :
194
+ return
195
+ context .util .sd_step_callback (state , BaseModelType .Bria )
196
+
197
+ return step_callback
0 commit comments