
Commit 6566ef7

ilanbria authored and psychedelicious committed
cr fixes 1
1 parent a6feefc · commit 6566ef7

25 files changed: +169 -2742 lines

invokeai/app/invocations/bria_controlnet.py

Lines changed: 6 additions & 47 deletions
@@ -1,5 +1,3 @@
-import cv2
-import numpy as np
 from PIL import Image
 from pydantic import BaseModel, Field
 
@@ -20,9 +18,7 @@
 )
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.bria.controlnet_aux.open_pose import Body, Face, Hand, OpenposeDetector
 from invokeai.backend.bria.controlnet_bria import BRIA_CONTROL_MODES
-from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.invocation_api import Classification, ImageOutput
 
 DEPTH_SMALL_V2_URL = "depth-anything/Depth-Anything-V2-Small-hf"

@@ -41,7 +37,6 @@ class BriaControlNetOutput(BaseInvocationOutput):
     """Bria ControlNet info"""
 
     control: BriaControlNetField = OutputField(description=FieldDescriptions.control)
-    preprocessed_images: ImageField = OutputField(description="The preprocessed control image")
 
 
 @invocation(

@@ -64,24 +59,18 @@ class BriaControlNetInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
         image_in = resize_img(context.images.get_pil(self.control_image.image_name))
-        if self.control_mode == "canny":
-            control_image = extract_canny(image_in)
-        elif self.control_mode == "depth":
-            control_image = extract_depth(image_in, context)
-        elif self.control_mode == "pose":
-            control_image = extract_openpose(image_in, context)
-        elif self.control_mode == "colorgrid":
+        if self.control_mode == "colorgrid":
             control_image = tile(64, image_in)
         elif self.control_mode == "recolor":
             control_image = convert_to_grayscale(image_in)
         elif self.control_mode == "tile":
             control_image = tile(16, image_in)
+        else:
+            control_image = image_in
 
         control_image = resize_img(control_image)
         image_dto = context.images.save(image=control_image)
-        image_output = ImageOutput.build(image_dto)
         return BriaControlNetOutput(
-            preprocessed_images=image_output.image,
             control=BriaControlNetField(
                 image=ImageField(image_name=image_dto.image_name),
                 model=self.control_model,

@@ -106,50 +95,20 @@ def invoke(self, context: InvocationContext) -> BriaControlNetOutput:
     }
 
 
-def extract_depth(image: Image.Image, context: InvocationContext):
-    loaded_model = context.models.load_remote_model(DEPTH_SMALL_V2_URL, DepthAnythingPipeline.load_model)
-
-    with loaded_model as depth_anything_detector:
-        assert isinstance(depth_anything_detector, DepthAnythingPipeline)
-        depth_map = depth_anything_detector.generate_depth(image)
-        return depth_map
-
-
-def extract_openpose(image: Image.Image, context: InvocationContext):
-    body_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}body_pose_model.pth", Body)
-    hand_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}hand_pose_model.pth", Hand)
-    face_model = context.models.load_remote_model(f"{HF_LLLYASVIEL}facenet.pth", Face)
-
-    with body_model as body_model, hand_model as hand_model, face_model as face_model:
-        open_pose_model = OpenposeDetector(body_model, hand_model, face_model)
-        processed_image_open_pose = open_pose_model(image, hand_and_face=True)
-
-    processed_image_open_pose = processed_image_open_pose.resize(image.size)
-    return processed_image_open_pose
-
-
-def extract_canny(input_image):
-    image = np.array(input_image)
-    image = cv2.Canny(image, 100, 200)
-    image = image[:, :, None]
-    image = np.concatenate([image, image, image], axis=2)
-    canny_image = Image.fromarray(image)
-    return canny_image
-
-
-def convert_to_grayscale(image):
+def convert_to_grayscale(image: Image.Image) -> Image.Image:
     gray_image = image.convert("L").convert("RGB")
     return gray_image
 
 
-def tile(downscale_factor, input_image):
+def tile(downscale_factor: int, input_image: Image.Image) -> Image.Image:
     control_image = input_image.resize(
         (input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
     ).resize(input_image.size, Image.Resampling.NEAREST)
     return control_image
 
 
-def resize_img(control_image):
+def resize_img(control_image: Image.Image) -> Image.Image:
     image_ratio = control_image.width / control_image.height
     ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
     to_height = RATIO_CONFIGS_1024[ratio]["height"]
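
Side note on the surviving preprocessors: they are pure-PIL transforms. A minimal, runnable sketch of the same techniques follows; the RATIO_CONFIGS_1024 entries and the final resampling filter are illustrative assumptions, since the diff only shows that the table's keys are aspect ratios and each value carries at least a "height".

from PIL import Image

# Hypothetical subset of RATIO_CONFIGS_1024, for illustration only.
RATIO_CONFIGS_1024 = {
    1.0: {"width": 1024, "height": 1024},
    0.75: {"width": 896, "height": 1152},
    4 / 3: {"width": 1152, "height": 896},
}


def tile(downscale_factor: int, input_image: Image.Image) -> Image.Image:
    # Downscale, then blow back up with nearest-neighbor: a blocky "color grid".
    return input_image.resize(
        (input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
    ).resize(input_image.size, Image.Resampling.NEAREST)


def resize_img(control_image: Image.Image) -> Image.Image:
    # Snap to the bucketed resolution whose aspect ratio is closest to the input's.
    image_ratio = control_image.width / control_image.height
    ratio = min(RATIO_CONFIGS_1024.keys(), key=lambda k: abs(k - image_ratio))
    config = RATIO_CONFIGS_1024[ratio]
    return control_image.resize((config["width"], config["height"]), Image.Resampling.LANCZOS)


img = Image.new("RGB", (1000, 750), "gray")
print(resize_img(tile(64, img)).size)  # (1152, 896) with the table above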

invokeai/app/invocations/bria_decoder.py

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,10 @@
     classification=Classification.Prototype,
 )
 class BriaDecoderInvocation(BaseInvocation):
+    """
+    Decode Bria latents to an image.
+    """
+
     vae: VAEField = InputField(
         description=FieldDescriptions.vae,
         input=Input.Connection,
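
The docstring is the only change here. For orientation, decoding latents with a diffusers AutoencoderKL conventionally looks like the sketch below; this is an assumption about the surrounding code, not a quote of BriaDecoderInvocation, which may apply different scaling or shift factors.

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL


@torch.no_grad()
def decode_latents(vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
    # Undo the training-time latent scaling, then decode to pixel space.
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents).sample  # (B, 3, H, W), roughly in [-1, 1]
    return (image / 2 + 0.5).clamp(0, 1)  # normalize to [0, 1] for export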

invokeai/app/invocations/bria_denoiser.py

Lines changed: 49 additions & 32 deletions
@@ -1,18 +1,21 @@
-from typing import List, Tuple
+from typing import Callable, List, Tuple
 
 import torch
 from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 
 from invokeai.app.invocations.bria_controlnet import BriaControlNetField
-from invokeai.app.invocations.fields import Input, InputField, LatentsField, OutputField
+from invokeai.app.invocations.bria_latent_noise import BriaLatentNoiseOutput
+from invokeai.app.invocations.fields import FluxConditioningField, Input, InputField, LatentsField, OutputField
 from invokeai.app.invocations.model import SubModelType, T5EncoderField, TransformerField, VAEField
 from invokeai.app.invocations.primitives import BaseInvocationOutput, FieldDescriptions
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.bria.controlnet_bria import BriaControlModes, BriaMultiControlNetModel
 from invokeai.backend.bria.controlnet_utils import prepare_control_images
 from invokeai.backend.bria.pipeline_bria_controlnet import BriaControlNetPipeline
 from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
+from invokeai.backend.model_manager.taxonomy import BaseModelType
+from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
 from invokeai.invocation_api import BaseInvocation, Classification, invocation, invocation_output

@@ -30,6 +33,11 @@ class BriaDenoiseInvocationOutput(BaseInvocationOutput):
     classification=Classification.Prototype,
 )
 class BriaDenoiseInvocation(BaseInvocation):
+    """
+    Denoise Bria latents using a Bria Pipeline.
+    """
+
     num_steps: int = InputField(
         default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
     )

@@ -52,31 +60,31 @@ class BriaDenoiseInvocation(BaseInvocation):
         input=Input.Connection,
         title="VAE",
     )
-    latents: LatentsField = InputField(
-        description="Latents to denoise",
-        input=Input.Connection,
-        title="Latents",
+    height: int = InputField(
+        default=1024,
+        title="Height",
+        description="The height of the output image",
     )
-    latent_image_ids: LatentsField = InputField(
-        description="Latent Image IDs to denoise",
+    width: int = InputField(
+        default=1024,
+        title="Width",
+        description="The width of the output image",
+    )
+    latent_noise: BriaLatentNoiseOutput = InputField(
+        description="Latent noise to denoise",
         input=Input.Connection,
-        title="Latent Image IDs",
+        title="Latent Noise",
     )
-    pos_embeds: LatentsField = InputField(
+    pos_embeds: FluxConditioningField = InputField(
         description="Positive Prompt Embeds",
         input=Input.Connection,
         title="Positive Prompt Embeds",
     )
-    neg_embeds: LatentsField = InputField(
+    neg_embeds: FluxConditioningField = InputField(
         description="Negative Prompt Embeds",
         input=Input.Connection,
         title="Negative Prompt Embeds",
     )
-    text_ids: LatentsField = InputField(
-        description="Text IDs",
-        input=Input.Connection,
-        title="Text IDs",
-    )
     control: BriaControlNetField | list[BriaControlNetField] | None = InputField(
         description="ControlNet",
         input=Input.Connection,

@@ -86,11 +94,10 @@ class BriaDenoiseInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
-        latents = context.tensors.load(self.latents.latents_name)
-        pos_embeds = context.tensors.load(self.pos_embeds.latents_name)
-        neg_embeds = context.tensors.load(self.neg_embeds.latents_name)
-        text_ids = context.tensors.load(self.text_ids.latents_name)
-        latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
+        latents = context.tensors.load(self.latent_noise.latents.latents_name)
+        pos_embeds = context.tensors.load(self.pos_embeds.conditioning_name)
+        neg_embeds = context.tensors.load(self.neg_embeds.conditioning_name)
+        latent_image_ids = context.tensors.load(self.latent_noise.latent_image_ids.latents_name)
         scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})
 
         device = None

@@ -114,11 +121,12 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
             control_model, control_images, control_modes, control_scales = self._prepare_multi_control(
                 context=context,
                 vae=vae,
-                width=1024,
-                height=1024,
+                width=self.width,
+                height=self.height,
                 device=vae.device,
             )
 
         pipeline = BriaControlNetPipeline(
             transformer=transformer,
             scheduler=scheduler,

@@ -129,31 +137,32 @@ def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
         )
         pipeline.to(device=transformer.device, dtype=transformer.dtype)
 
-        latents = pipeline(
+        output_latents = pipeline(
             control_image=control_images,
             control_mode=control_modes,
-            width=1024,
-            height=1024,
+            width=self.width,
+            height=self.height,
             controlnet_conditioning_scale=control_scales,
             num_inference_steps=self.num_steps,
             max_sequence_length=128,
             guidance_scale=self.guidance_scale,
             latents=latents,
             latent_image_ids=latent_image_ids,
-            text_ids=text_ids,
             prompt_embeds=pos_embeds,
             negative_prompt_embeds=neg_embeds,
             output_type="latent",
+            step_callback=_build_step_callback(context),
         )[0]
 
-        assert isinstance(latents, torch.Tensor)
-        saved_input_latents_tensor = context.tensors.save(latents)
-        latents_output = LatentsField(latents_name=saved_input_latents_tensor)
-        return BriaDenoiseInvocationOutput(latents=latents_output)
+        assert isinstance(output_latents, torch.Tensor)
+        saved_input_latents_tensor = context.tensors.save(output_latents)
+        return BriaDenoiseInvocationOutput(latents=LatentsField(latents_name=saved_input_latents_tensor))
 
     def _prepare_multi_control(
         self, context: InvocationContext, vae: AutoencoderKL, width: int, height: int, device: torch.device
-    ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[torch.Tensor], List[float]]:
+    ) -> Tuple[BriaMultiControlNetModel, List[torch.Tensor], List[int], List[float]]:
        control = self.control if isinstance(self.control, list) else [self.control]
        control_images, control_models, control_modes, control_scales = [], [], [], []
        for controlnet in control:

@@ -178,3 +187,11 @@ def _prepare_multi_control(
             device=device,
         )
         return control_model, tensored_control_images, tensored_control_modes, control_scales
+
+
+def _build_step_callback(context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+    def step_callback(state: PipelineIntermediateState) -> None:
+        context.util.sd_step_callback(state, BaseModelType.Bria)
+
+    return step_callback
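
The new _build_step_callback closes over the invocation context so the pipeline only ever sees a plain callable. The same pattern in isolation, with a hypothetical ProgressState standing in for PipelineIntermediateState:

from dataclasses import dataclass
from typing import Callable


@dataclass
class ProgressState:  # stand-in for PipelineIntermediateState
    step: int
    total_steps: int


def build_step_callback(report: Callable[[str], None]) -> Callable[[ProgressState], None]:
    def step_callback(state: ProgressState) -> None:
        # The bound reporter fires once per denoising step.
        report(f"step {state.step + 1}/{state.total_steps}")

    return step_callback


cb = build_step_callback(print)
for i in range(3):
    cb(ProgressState(step=i, total_steps=3))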

invokeai/app/invocations/bria_latent_sampler.py renamed to invokeai/app/invocations/bria_latent_noise.py

Lines changed: 35 additions & 16 deletions
@@ -1,4 +1,5 @@
 import torch
+from pydantic import BaseModel, Field
 
 from invokeai.app.invocations.fields import Input, InputField, OutputField
 from invokeai.app.invocations.model import TransformerField

@@ -17,23 +18,28 @@
 )
 
 
-@invocation_output("bria_latent_sampler_output")
-class BriaLatentSamplerInvocationOutput(BaseInvocationOutput):
-    """Base class for nodes that output a CogView text conditioning tensor."""
-
-    latents: LatentsField = OutputField(description=FieldDescriptions.cond)
-    latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.cond)
+class BriaLatentNoiseOutput(BaseModel):
+    latents: LatentsField
+    latent_image_ids: LatentsField
 
+
+@invocation_output("bria_latent_noise_output")
+class BriaLatentNoiseInvocationOutput(BaseInvocationOutput):
+    """Base class for nodes that output Bria latent tensors."""
+
+    latent_noise: BriaLatentNoiseOutput = OutputField(description="The latent noise, containing latents and latent image ids.")
+    height: int = OutputField(description="The height of the output image", default=1024)
+    width: int = OutputField(description="The width of the output image", default=1024)
 
 @invocation(
-    "bria_latent_sampler",
-    title="Latent Sampler - Bria",
+    "bria_latent_noise",
+    title="Latent Noise - Bria",
     tags=["image", "bria"],
     category="image",
     version="1.0.0",
     classification=Classification.Prototype,
 )
-class BriaLatentSamplerInvocation(BaseInvocation):
+class BriaLatentNoiseInvocation(BaseInvocation):
+    """Generate latent noise for Bria."""
+
     seed: int = InputField(
         default=42,
         title="Seed",

@@ -44,22 +50,31 @@ class BriaLatentSamplerInvocation(BaseInvocation):
         input=Input.Connection,
         title="Transformer",
     )
+    height: int = InputField(
+        default=1024,
+        title="Height",
+        description="The height of the output image",
+    )
+    width: int = InputField(
+        default=1024,
+        title="Width",
+        description="The width of the output image",
+    )
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
+    def invoke(self, context: InvocationContext) -> BriaLatentNoiseInvocationOutput:
         with context.models.load(self.transformer.transformer) as transformer:
             device = transformer.device
             dtype = transformer.dtype
 
-        height, width = 1024, 1024
         generator = torch.Generator(device=device).manual_seed(self.seed)
 
         num_channels_latents = 4
         latents, latent_image_ids = prepare_latents(
             batch_size=1,
             num_channels_latents=num_channels_latents,
-            height=height,
-            width=width,
+            height=self.height,
+            width=self.width,
             dtype=dtype,
             device=device,
             generator=generator,

@@ -70,7 +85,11 @@ def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
         latents_output = LatentsField(latents_name=saved_latents_tensor)
         latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)
 
-        return BriaLatentSamplerInvocationOutput(
-            latents=latents_output,
-            latent_image_ids=latent_image_ids_output,
+        return BriaLatentNoiseInvocationOutput(
+            latent_noise=BriaLatentNoiseOutput(
+                latents=latents_output,
+                latent_image_ids=latent_image_ids_output,
+            ),
+            height=self.height,
+            width=self.width,
         )
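
Two notes on this rename. First, BriaLatentNoiseOutput is a plain pydantic BaseModel rather than an invocation output, so the latents and their image ids travel through the graph as a single connection that the denoiser unpacks. Second, prepare_latents itself is not shown in the diff; the sketch below is a hedged guess at what a FLUX-style implementation does with these arguments (a 1/8-scale VAE grid packed into 2x2 patches, plus one positional-id triple per token), and Bria's actual helper may differ.

import torch


def prepare_latents_sketch(
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: torch.Generator,
) -> tuple[torch.Tensor, torch.Tensor]:
    # 1/8 VAE downscale followed by 2x2 patch packing -> a /16 token grid.
    grid_h, grid_w = height // 16, width // 16
    latents = torch.randn(
        (batch_size, grid_h * grid_w, num_channels_latents * 4),
        generator=generator, device=device, dtype=dtype,
    )
    # Positional ids: one (text, row, col) triple per latent token.
    ids = torch.zeros(grid_h, grid_w, 3, device=device, dtype=dtype)
    ids[..., 1] += torch.arange(grid_h, device=device)[:, None]
    ids[..., 2] += torch.arange(grid_w, device=device)[None, :]
    return latents, ids.reshape(grid_h * grid_w, 3)


latents, latent_image_ids = prepare_latents_sketch(
    1, 4, 1024, 1024, torch.float32, torch.device("cpu"),
    torch.Generator().manual_seed(42),
)
print(latents.shape, latent_image_ids.shape)  # [1, 4096, 16] and [4096, 3]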
