Merged

Changes from all commits · 27 commits
9cc852c
Base code from draft PR
StAlKeR7779 Jul 12, 2024
0bc6037
A bit rework conditioning convert to unet kwargs
StAlKeR7779 Jul 12, 2024
87e96e1
Rename modifiers to callbacks, convert order to int, a bit unify inje…
StAlKeR7779 Jul 12, 2024
bd8ae5d
Simplify guidance modes
StAlKeR7779 Jul 12, 2024
3a9dda9
Renames
StAlKeR7779 Jul 12, 2024
7e00526
Remove overrides logic for now
StAlKeR7779 Jul 12, 2024
e961dd1
Remove remains of priority logic
StAlKeR7779 Jul 12, 2024
499e4d4
Add preview extension to check logic
StAlKeR7779 Jul 12, 2024
d623bd4
Fix condtionings logic
StAlKeR7779 Jul 15, 2024
fd8d1c1
Remove 'del' operator overload
StAlKeR7779 Jul 15, 2024
9f088d1
Multiple small fixes
StAlKeR7779 Jul 15, 2024
608cbe3
Separate inputs in denoise context
StAlKeR7779 Jul 16, 2024
cec345c
Change attention processor apply logic
StAlKeR7779 Jul 16, 2024
b7c6c63
Added some comments
StAlKeR7779 Jul 16, 2024
cd1bc15
Rename sequential as private variable
StAlKeR7779 Jul 17, 2024
ae6d4fb
Move out _concat_conditionings_for_batch submethods
StAlKeR7779 Jul 17, 2024
03e22c2
Convert conditioning_mode to enum
StAlKeR7779 Jul 17, 2024
137202b
Remove patch_unet logic for now
StAlKeR7779 Jul 17, 2024
79e35bd
Minor fixes
StAlKeR7779 Jul 17, 2024
2c2ec8f
Comments, a bit refactor
StAlKeR7779 Jul 17, 2024
3f79467
Ruff format
StAlKeR7779 Jul 17, 2024
2ef3b49
Add run cancelling logic to extension manager
StAlKeR7779 Jul 17, 2024
710dc6b
Merge branch 'main' into stalker7779/backend_base
StAlKeR7779 Jul 17, 2024
0c56d4a
Ryan's suggested changes to extension manager/extensions
StAlKeR7779 Jul 18, 2024
83a86ab
Add unit tests for ExtensionsManager and ExtensionBase.
RyanJDick Jul 19, 2024
39e10d8
Add invocation cancellation logic to patchers
StAlKeR7779 Jul 19, 2024
78d2b1b
Merge branch 'main' into stalker-backend_base
RyanJDick Jul 19, 2024
122 changes: 115 additions & 7 deletions invokeai/app/invocations/denoise_latents.py
@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -39,6 +40,7 @@
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
@@ -53,6 +55,11 @@
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
@@ -314,9 +321,10 @@ def get_conditioning_data(
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
device: torch.device,
dtype: torch.dtype,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
@@ -330,25 +338,25 @@
uncond_list = [uncond_list]

cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
cond_list, context, device, dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
uncond_list, context, device, dtype
)

cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)

if isinstance(cfg_scale, list):
@@ -707,9 +715,108 @@ def prepare_noise_and_latents(

return seed, noise, latents

def invoke(self, context: InvocationContext) -> LatentsOutput:
if os.environ.get("USE_MODULAR_DENOISE", False):
return self._new_invoke(context)
else:
return self._old_invoke(context)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)

device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype()

seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)

_, _, latent_height, latent_width = latents.shape

conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
cfg_scale=self.cfg_scale,
steps=self.steps,
latent_height=latent_height,
latent_width=latent_width,
device=device,
dtype=dtype,
# TODO: old backend, remove
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)

scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)

timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
seed=seed,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
)

denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(unet),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.detach().to("cpu")
TorchDevice.empty_cache()

name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -788,7 +895,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
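For context, a minimal sketch of how the new opt-in path and extension wiring fit together, using only names that appear in the diff above (the no-op `is_canceled` lambda and the stand-in `step_callback` are illustrative):

```python
import os

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState, PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager

# The modular backend is opt-in: any truthy value routes invoke() to _new_invoke().
os.environ["USE_MODULAR_DENOISE"] = "1"


def step_callback(state: PipelineIntermediateState) -> None:
    # Stand-in for context.util.sd_step_callback(): report per-step progress.
    print(f"step {state.step}/{state.total_steps}")


ext_manager = ExtensionsManager(is_canceled=lambda: False)
ext_manager.add_extension(PreviewExt(step_callback))

# SETUP callbacks fire once before the denoise loop; the remaining callback
# types (see ExtensionCallbackType) fire from inside StableDiffusionBackend.
# ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
```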
23 changes: 21 additions & 2 deletions invokeai/backend/model_patcher.py
@@ -5,7 +5,7 @@

import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

import numpy as np
import torch
@@ -32,8 +32,27 @@
"""


# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
@contextmanager
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
"""A context manager that patches `unet` with the provided attention processor.

Args:
unet (UNet2DConditionModel): The UNet model to patch.
processor_cls (Type[Any]): Class that will be instantiated for each attention key and passed to set_attn_processor(...).
"""
unet_orig_processors = unet.attn_processors

# Create a separate instance for each attention layer, so that each attention can be modified independently.
unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
try:
unet.set_attn_processor(unet_new_processors)
yield None

finally:
unet.set_attn_processor(unet_orig_processors)

@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
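A brief usage sketch of the new context manager (assumes a loaded `UNet2DConditionModel` named `unet`; `CustomAttnProcessor2_0` is the class this PR passes in via `DenoiseInputs.attention_processor_cls`):

```python
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0

# Each attention key gets its own processor instance, so per-layer state can be
# attached without cross-talk; the original processors are restored on exit,
# even if denoising raises.
with ModelPatcher.patch_unet_attention_processor(unet, CustomAttnProcessor2_0):
    ...  # run the denoise loop here
```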
131 changes: 131 additions & 0 deletions invokeai/backend/stable_diffusion/denoise_context.py
@@ -0,0 +1,131 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData


@dataclass
class UNetKwargs:
sample: torch.Tensor
timestep: Union[torch.Tensor, float, int]
encoder_hidden_states: torch.Tensor

class_labels: Optional[torch.Tensor] = None
timestep_cond: Optional[torch.Tensor] = None
attention_mask: Optional[torch.Tensor] = None
cross_attention_kwargs: Optional[Dict[str, Any]] = None
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
mid_block_additional_residual: Optional[torch.Tensor] = None
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
encoder_attention_mask: Optional[torch.Tensor] = None
# return_dict: bool = True


@dataclass
class DenoiseInputs:
"""Initial variables passed to denoise. Supposed to be unchanged."""

# The latent-space image to denoise.
# Shape: [batch, channels, latent_height, latent_width]
# - If we are inpainting, this is the initial latent image before noise has been added.
# - If we are generating a new image, this should be initialized to zeros.
# - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
orig_latents: torch.Tensor

# kwargs forwarded to the scheduler.step() method.
scheduler_step_kwargs: dict[str, Any]

# Text conditioning data.
conditioning_data: TextConditioningData

# Noise used for two purposes:
# 1. Used by the scheduler to noise the initial `latents` before denoising.
# 2. Used to noise the `masked_latents` when inpainting.
# `noise` should be None if the `latents` tensor has already been noised.
# Shape: [1 or batch, channels, latent_height, latent_width]
noise: Optional[torch.Tensor]

# The seed used to generate the noise for the denoising process.
# HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
# same noise used earlier in the pipeline. This should really be handled in a clearer way.
seed: int

# The timestep schedule for the denoising process.
timesteps: torch.Tensor

# The first timestep in the schedule. This is used to determine the initial noise level, so
# should be populated if you want noise applied *even* if timesteps is empty.
init_timestep: torch.Tensor

# Class of attention processor that is used.
attention_processor_cls: Type[Any]


@dataclass
class DenoiseContext:
"""Context with all variables in denoise"""

# Initial variables passed to denoise. Supposed to be unchanged.
inputs: DenoiseInputs

# Scheduler used to apply noise predictions.
scheduler: SchedulerMixin

# UNet model.
unet: Optional[UNet2DConditionModel] = None

# Current state of latent-space image in denoising process.
# None until `pre_denoise_loop` callback.
# Shape: [batch, channels, latent_height, latent_width]
latents: Optional[torch.Tensor] = None

# Current denoising step index.
# None until `pre_step` callback.
step_index: Optional[int] = None

# Current denoising step timestep.
# None until `pre_step` callback.
timestep: Optional[torch.Tensor] = None

# Arguments that will be passed to the UNet model.
# Available in `pre_unet`/`post_unet` callbacks; otherwise None.
unet_kwargs: Optional[UNetKwargs] = None

# SchedulerOutput returned from the step function (normally generated by the scheduler).
# Intended for use only in the `post_step` callback; otherwise may be None.
step_output: Optional[SchedulerOutput] = None

# Scaled version of `latents`, which will be used to initialize unet_kwargs.
# Available in callbacks inside step (between `pre_step` and `post_step`).
# Shape: [batch, channels, latent_height, latent_width]
latent_model_input: Optional[torch.Tensor] = None

# [TMP] Defines which conditionings the current UNet call will be run on.
# Available in `pre_unet`/`post_unet` callbacks; otherwise None.
conditioning_mode: Optional[ConditioningMode] = None

# [TMP] Noise predictions from negative conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks; otherwise None.
# Shape: [batch, channels, latent_height, latent_width]
negative_noise_pred: Optional[torch.Tensor] = None

# [TMP] Noise predictions from positive conditioning.
# Available in `apply_cfg` and `post_apply_cfg` callbacks; otherwise None.
# Shape: [batch, channels, latent_height, latent_width]
positive_noise_pred: Optional[torch.Tensor] = None

# Combined noise prediction from the passed conditionings.
# Available in `apply_cfg` and `post_apply_cfg` callbacks; otherwise None.
# Shape: [batch, channels, latent_height, latent_width]
noise_pred: Optional[torch.Tensor] = None

# Dictionary for extensions to pass extra info about the denoise process to other extensions.
extra: dict = field(default_factory=dict)
11 changes: 1 addition & 10 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
11 changes: 1 addition & 10 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -23,21 +23,12 @@
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel


@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None


@dataclass
class AddsMaskGuidance:
mask: torch.Tensor
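With the dataclass relocated, downstream code imports `PipelineIntermediateState` from its new home; a small sketch (field values are illustrative):

```python
import torch

from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState

state = PipelineIntermediateState(
    step=0,
    order=1,
    total_steps=30,
    timestep=999,
    latents=torch.zeros(1, 4, 64, 64),  # [batch, channels, latent_height, latent_width]
)
```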