Merged

Commits (27)
9cc852c
Base code from draft PR
StAlKeR7779 Jul 12, 2024
0bc6037
A bit rework conditioning convert to unet kwargs
StAlKeR7779 Jul 12, 2024
87e96e1
Rename modifiers to callbacks, convert order to int, a bit unify inje…
StAlKeR7779 Jul 12, 2024
bd8ae5d
Simplify guidance modes
StAlKeR7779 Jul 12, 2024
3a9dda9
Renames
StAlKeR7779 Jul 12, 2024
7e00526
Remove overrides logic for now
StAlKeR7779 Jul 12, 2024
e961dd1
Remove remains of priority logic
StAlKeR7779 Jul 12, 2024
499e4d4
Add preview extension to check logic
StAlKeR7779 Jul 12, 2024
d623bd4
Fix condtionings logic
StAlKeR7779 Jul 15, 2024
fd8d1c1
Remove 'del' operator overload
StAlKeR7779 Jul 15, 2024
9f088d1
Multiple small fixes
StAlKeR7779 Jul 15, 2024
608cbe3
Separate inputs in denoise context
StAlKeR7779 Jul 16, 2024
cec345c
Change attention processor apply logic
StAlKeR7779 Jul 16, 2024
b7c6c63
Added some comments
StAlKeR7779 Jul 16, 2024
cd1bc15
Rename sequential as private variable
StAlKeR7779 Jul 17, 2024
ae6d4fb
Move out _concat_conditionings_for_batch submethods
StAlKeR7779 Jul 17, 2024
03e22c2
Convert conditioning_mode to enum
StAlKeR7779 Jul 17, 2024
137202b
Remove patch_unet logic for now
StAlKeR7779 Jul 17, 2024
79e35bd
Minor fixes
StAlKeR7779 Jul 17, 2024
2c2ec8f
Comments, a bit refactor
StAlKeR7779 Jul 17, 2024
3f79467
Ruff format
StAlKeR7779 Jul 17, 2024
2ef3b49
Add run cancelling logic to extension manager
StAlKeR7779 Jul 17, 2024
710dc6b
Merge branch 'main' into stalker7779/backend_base
StAlKeR7779 Jul 17, 2024
0c56d4a
Ryan's suggested changes to extension manager/extensions
StAlKeR7779 Jul 18, 2024
83a86ab
Add unit tests for ExtensionsManager and ExtensionBase.
RyanJDick Jul 19, 2024
39e10d8
Add invocation cancellation logic to patchers
StAlKeR7779 Jul 19, 2024
78d2b1b
Merge branch 'main' into stalker-backend_base
RyanJDick Jul 19, 2024
118 changes: 111 additions & 7 deletions invokeai/app/invocations/denoise_latents.py
@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -39,6 +40,7 @@
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
@@ -53,6 +55,10 @@
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extensions import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +320,10 @@ def get_conditioning_data(
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
device: torch.device,
dtype: torch.dtype,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
@@ -330,25 +337,25 @@
uncond_list = [uncond_list]

cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
cond_list, context, device, dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
uncond_list, context, device, dtype
)

cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)

if isinstance(cfg_scale, list):
@@ -707,9 +714,105 @@ def prepare_noise_and_latents(

return seed, noise, latents

def invoke(self, context: InvocationContext) -> LatentsOutput:
if os.environ.get("USE_MODULAR_DENOISE", False):
return self._new_invoke(context)
else:
return self._old_invoke(context)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
with ExitStack() as exit_stack:
ext_manager = ExtensionsManager()

device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype()

seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)

_, _, latent_height, latent_width = latents.shape

conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
cfg_scale=self.cfg_scale,
steps=self.steps,
latent_height=latent_height,
latent_width=latent_width,
device=device,
dtype=dtype,
# TODO: old backend, remove
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)

scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)

timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
seed=seed,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
)

denoise_ctx = DenoiseContext(
latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
unet=None,
scheduler=scheduler,
)

### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

# ext: t2i/ip adapter
ext_manager.callbacks.setup(denoise_ctx, ext_manager)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
# ext: controlnet
ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu") # TODO: detach?
TorchDevice.empty_cache()

name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +891,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
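A note for reviewers of the `invoke` gate above: `os.environ.get("USE_MODULAR_DENOISE", False)` returns the raw string whenever the variable is set, so any non-empty value — including `"0"` or `"false"` — is truthy and selects the modular path. A minimal standalone sketch of that behavior (plain Python, no InvokeAI dependencies):

```python
import os

# os.environ.get returns the stored string when the variable is set and the
# default (False) only when it is unset, so truthiness depends on string
# length, not on what the value "means".
for value in ("1", "0", "false", ""):
    os.environ["USE_MODULAR_DENOISE"] = value
    enabled = bool(os.environ.get("USE_MODULAR_DENOISE", False))
    print(f"USE_MODULAR_DENOISE={value!r} -> modular path: {enabled}")

# USE_MODULAR_DENOISE='1'     -> modular path: True
# USE_MODULAR_DENOISE='0'     -> modular path: True
# USE_MODULAR_DENOISE='false' -> modular path: True
# USE_MODULAR_DENOISE=''      -> modular path: False
```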
60 changes: 60 additions & 0 deletions invokeai/backend/stable_diffusion/denoise_context.py
@@ -0,0 +1,60 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


@dataclass
class UNetKwargs:
sample: torch.Tensor
timestep: Union[torch.Tensor, float, int]
encoder_hidden_states: torch.Tensor

class_labels: Optional[torch.Tensor] = None
timestep_cond: Optional[torch.Tensor] = None
attention_mask: Optional[torch.Tensor] = None
cross_attention_kwargs: Optional[Dict[str, Any]] = None
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
mid_block_additional_residual: Optional[torch.Tensor] = None
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
encoder_attention_mask: Optional[torch.Tensor] = None
# return_dict: bool = True


@dataclass
class DenoiseContext:
latents: torch.Tensor
scheduler_step_kwargs: dict[str, Any]
conditioning_data: TextConditioningData
noise: Optional[torch.Tensor]
seed: int
timesteps: torch.Tensor
init_timestep: torch.Tensor

scheduler: SchedulerMixin
unet: Optional[UNet2DConditionModel] = None

orig_latents: Optional[torch.Tensor] = None
step_index: Optional[int] = None
timestep: Optional[torch.Tensor] = None
unet_kwargs: Optional[UNetKwargs] = None
step_output: Optional[SchedulerOutput] = None

latent_model_input: Optional[torch.Tensor] = None
conditioning_mode: Optional[str] = None
negative_noise_pred: Optional[torch.Tensor] = None
positive_noise_pred: Optional[torch.Tensor] = None
noise_pred: Optional[torch.Tensor] = None

extra: dict = field(default_factory=dict)

def __delattr__(self, name: str):
setattr(self, name, None)
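The `__delattr__` override at the end of `DenoiseContext` changes the meaning of `del`: deleting a field resets it to `None` rather than removing the attribute, so per-step fields such as `noise_pred` can be cleared between steps without later reads raising `AttributeError`. A self-contained sketch of the same pattern (toy class, not the real `DenoiseContext`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Ctx:
    noise_pred: Optional[int] = None

    def __delattr__(self, name: str) -> None:
        # As in DenoiseContext: `del ctx.field` resets the field to None
        # instead of removing the attribute from the instance.
        setattr(self, name, None)

ctx = Ctx(noise_pred=42)
del ctx.noise_pred
assert ctx.noise_pred is None  # attribute still exists, value cleared
```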
18 changes: 9 additions & 9 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -23,19 +23,19 @@
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel


@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
# @dataclass
# class PipelineIntermediateState:
# step: int
# order: int
# total_steps: int
# timestep: int
# latents: torch.Tensor
# predicted_original: Optional[torch.Tensor] = None


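The commented-out dataclass above is moved rather than dropped: `PipelineIntermediateState` now lives in `invokeai.backend.stable_diffusion.extensions` (see the new import at the top of this file), where the `PreviewExt` extension consumes it. A hedged sketch of wiring a step-preview callback through the new extension API, using only names that appear in this diff:

```python
from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState, PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

def on_step(state: PipelineIntermediateState) -> None:
    # Called once per denoising step with the current latents.
    print(f"step {state.step}/{state.total_steps} (t={state.timestep})")

ext_manager = ExtensionsManager()
ext_manager.add_extension(PreviewExt(on_step))
# ext_manager is then passed, together with the DenoiseContext, to
# StableDiffusionBackend.latents_from_embeddings(), as in _new_invoke above.
```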
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -5,6 +5,7 @@
import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData


@dataclass
@@ -103,7 +104,7 @@ def __init__(
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
guidance_rescale_multiplier: float = 0, # TODO: old backend, remove
):
self.uncond_text = uncond_text
self.cond_text = cond_text
@@ -114,10 +115,96 @@
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# TODO: old backend, remove
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier

def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)

def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
_, _, h, w = unet_kwargs.sample.shape
device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype

# TODO: combine regions with conditionings
if conditioning_mode == "both":
conditionings = [self.uncond_text.embeds, self.cond_text.embeds]
c_regions = [self.uncond_regions, self.cond_regions]
elif conditioning_mode == "positive":
conditionings = [self.cond_text.embeds]
c_regions = [self.cond_regions]
else:
conditionings = [self.uncond_text.embeds]
c_regions = [self.uncond_regions]

encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(conditionings)

unet_kwargs.encoder_hidden_states = encoder_hidden_states
unet_kwargs.encoder_attention_mask = encoder_attention_mask

if self.is_sdxl():
added_cond_kwargs = dict( # noqa: C408
text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
time_ids=torch.cat([c.add_time_ids for c in conditionings]),
)

unet_kwargs.added_cond_kwargs = added_cond_kwargs

if any(r is not None for r in c_regions):
tmp_regions = []
for c, r in zip(conditionings, c_regions, strict=True):
if r is None:
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
tmp_regions.append(r)

if unet_kwargs.cross_attention_kwargs is None:
unet_kwargs.cross_attention_kwargs = {}

unet_kwargs.cross_attention_kwargs.update(
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
)

def _concat_conditionings_for_batch(self, conditionings):
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones(
(cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
)

if cond.shape[1] < max_len:
conditioning_attention_mask = _pad_zeros(
conditioning_attention_mask,
pad_shape=(cond.shape[0], max_len - cond.shape[1]),
dim=1,
)

cond = _pad_zeros(
cond,
pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
dim=1,
)

if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])

return cond, encoder_attention_mask

encoder_attention_mask = None
max_len = max([c.shape[1] for c in conditionings])
if any(c.shape[1] != max_len for c in conditionings):
for i in range(len(conditionings)):
conditionings[i], encoder_attention_mask = _pad_conditioning(
conditionings[i], max_len, encoder_attention_mask
)

return torch.cat(conditionings), encoder_attention_mask
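To make the padding logic above concrete: when the positive and negative prompts produce embeddings of different token lengths (e.g. a long prompt chunked to 154 tokens vs. a standard 77), the shorter tensor is zero-padded along the token axis and a joint `encoder_attention_mask` marks which positions are real. A standalone toy version, with made-up shapes:

```python
import torch

def pad_and_concat(conditionings: list[torch.Tensor]):
    # Mirrors _concat_conditionings_for_batch: zero-pad each (B, T, C) tensor
    # to the longest T; the mask holds 1.0 for real tokens, 0.0 for padding.
    max_len = max(c.shape[1] for c in conditionings)
    if all(c.shape[1] == max_len for c in conditionings):
        return torch.cat(conditionings), None  # equal lengths: no mask needed

    padded, masks = [], []
    for c in conditionings:
        pad = max_len - c.shape[1]
        padded.append(torch.cat([c, torch.zeros(c.shape[0], pad, c.shape[2])], dim=1))
        mask = torch.ones(c.shape[0], c.shape[1])
        masks.append(torch.cat([mask, torch.zeros(c.shape[0], pad)], dim=1))
    return torch.cat(padded), torch.cat(masks)

uncond = torch.randn(1, 77, 768)   # hypothetical short negative prompt
cond = torch.randn(1, 154, 768)    # hypothetical long positive prompt
hidden_states, attn_mask = pad_and_concat([uncond, cond])
print(hidden_states.shape, attn_mask.shape)  # (2, 154, 768) and (2, 154)
```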