Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9cc852c
Base code from draft PR
StAlKeR7779 Jul 12, 2024
0bc6037
A bit rework conditioning convert to unet kwargs
StAlKeR7779 Jul 12, 2024
87e96e1
Rename modifiers to callbacks, convert order to int, a bit unify inje…
StAlKeR7779 Jul 12, 2024
bd8ae5d
Simplify guidance modes
StAlKeR7779 Jul 12, 2024
3a9dda9
Renames
StAlKeR7779 Jul 12, 2024
7e00526
Remove overrides logic for now
StAlKeR7779 Jul 12, 2024
e961dd1
Remove remains of priority logic
StAlKeR7779 Jul 12, 2024
499e4d4
Add preview extension to check logic
StAlKeR7779 Jul 12, 2024
d623bd4
Fix condtionings logic
StAlKeR7779 Jul 15, 2024
fd8d1c1
Remove 'del' operator overload
StAlKeR7779 Jul 15, 2024
9f088d1
Multiple small fixes
StAlKeR7779 Jul 15, 2024
608cbe3
Separate inputs in denoise context
StAlKeR7779 Jul 16, 2024
cec345c
Change attention processor apply logic
StAlKeR7779 Jul 16, 2024
b7c6c63
Added some comments
StAlKeR7779 Jul 16, 2024
cd1bc15
Rename sequential as private variable
StAlKeR7779 Jul 17, 2024
ae6d4fb
Move out _concat_conditionings_for_batch submethods
StAlKeR7779 Jul 17, 2024
03e22c2
Convert conditioning_mode to enum
StAlKeR7779 Jul 17, 2024
137202b
Remove patch_unet logic for now
StAlKeR7779 Jul 17, 2024
79e35bd
Minor fixes
StAlKeR7779 Jul 17, 2024
2c2ec8f
Comments, a bit refactor
StAlKeR7779 Jul 17, 2024
3f79467
Ruff format
StAlKeR7779 Jul 17, 2024
2ef3b49
Add run cancelling logic to extension manager
StAlKeR7779 Jul 17, 2024
710dc6b
Merge branch 'main' into stalker7779/backend_base
StAlKeR7779 Jul 17, 2024
0c56d4a
Ryan's suggested changes to extension manager/extensions
StAlKeR7779 Jul 18, 2024
83a86ab
Add unit tests for ExtensionsManager and ExtensionBase.
RyanJDick Jul 19, 2024
39e10d8
Add invocation cancellation logic to patchers
StAlKeR7779 Jul 19, 2024
78d2b1b
Merge branch 'main' into stalker-backend_base
RyanJDick Jul 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 116 additions & 7 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

Expand Down Expand Up @@ -39,6 +40,7 @@
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
Expand All @@ -53,6 +55,10 @@
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
Expand Down Expand Up @@ -314,9 +320,10 @@ def get_conditioning_data(
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
device: torch.device,
dtype: torch.dtype,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
Expand All @@ -330,25 +337,25 @@ def get_conditioning_data(
uncond_list = [uncond_list]

cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
cond_list, context, device, dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
uncond_list, context, device, dtype
)

cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
dtype=dtype,
)

if isinstance(cfg_scale, list):
Expand Down Expand Up @@ -707,9 +714,110 @@ def prepare_noise_and_latents(

return seed, noise, latents

def invoke(self, context: InvocationContext) -> LatentsOutput:
    """Run denoising, dispatching to the modular or the legacy implementation.

    The modular implementation (`_new_invoke`) is opt-in through the
    `USE_MODULAR_DENOISE` environment variable; otherwise the legacy
    implementation (`_old_invoke`) is used.

    Args:
        context: The invocation context forwarded unchanged to the selected implementation.

    Returns:
        Whatever the selected implementation returns.
    """
    # Environment variables are always strings, so testing the raw value for truthiness would
    # treat "0" or "false" as enabled. Normalise and compare against the explicit "off" values;
    # any other non-empty value keeps enabling the modular backend, as before.
    flag = os.environ.get("USE_MODULAR_DENOISE", "").strip().lower()
    if flag not in ("", "0", "false", "no", "off"):
        return self._new_invoke(context)
    return self._old_invoke(context)

@torch.no_grad()
@SilenceWarnings()  # This quenches the NSFW nag from diffusers.
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
    """Denoise latents using the modular, extension-based backend.

    Prepares noise/latents, conditioning data and the scheduler, bundles them into a
    `DenoiseContext`, registers extensions on an `ExtensionsManager`, then runs
    `StableDiffusionBackend.latents_from_embeddings` with the UNet loaded on device and
    patched with the configured attention processor.

    Args:
        context: The invocation context used to load models/configs, report step progress
            and save the resulting tensor.

    Returns:
        A `LatentsOutput` built from the saved result latents (with `seed=None`).
    """
    # TODO: remove suppression when extensions which use models are added
    with ExitStack() as exit_stack:  # noqa: F841
        ext_manager = ExtensionsManager()

        device = TorchDevice.choose_torch_device()
        dtype = TorchDevice.choose_torch_dtype()

        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
        latents = latents.to(device=device, dtype=dtype)
        if noise is not None:
            noise = noise.to(device=device, dtype=dtype)

        _, _, latent_height, latent_width = latents.shape

        conditioning_data = self.get_conditioning_data(
            context=context,
            positive_conditioning_field=self.positive_conditioning,
            negative_conditioning_field=self.negative_conditioning,
            cfg_scale=self.cfg_scale,
            steps=self.steps,
            latent_height=latent_height,
            latent_width=latent_width,
            device=device,
            dtype=dtype,
            # TODO: old backend, remove
            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
        )

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
            seed=seed,
        )

        timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
            scheduler,
            seed=seed,
            device=device,
            steps=self.steps,
            denoising_start=self.denoising_start,
            denoising_end=self.denoising_end,
        )

        denoise_ctx = DenoiseContext(
            inputs=DenoiseInputs(
                orig_latents=latents,
                timesteps=timesteps,
                init_timestep=init_timestep,
                noise=noise,
                seed=seed,
                scheduler_step_kwargs=scheduler_step_kwargs,
                conditioning_data=conditioning_data,
                attention_processor_cls=CustomAttnProcessor2_0,
            ),
            # The UNet is only available once it is loaded on device below.
            unet=None,
            scheduler=scheduler,
        )

        # get the unet's config so that we can pass the base to sd_step_callback()
        # (resolved before the callback is defined so the closure never sees an unbound name)
        unet_config = context.models.get_config(self.unet.unet.key)

        ### preview
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, unet_config.base)

        ext_manager.add_extension(PreviewExt(step_callback))

        # ext: t2i/ip adapter
        ext_manager.callbacks.setup(denoise_ctx, ext_manager)

        unet_info = context.models.load(self.unet.unet)
        assert isinstance(unet_info.model, UNet2DConditionModel)
        with (
            unet_info.model_on_device() as (model_state_dict, unet),
            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
            # ext: controlnet
            ext_manager.patch_extensions(unet),
            # ext: freeu, seamless, ip adapter, lora
            ext_manager.patch_unet(model_state_dict, unet),
        ):
            sd_backend = StableDiffusionBackend(unet, scheduler)
            denoise_ctx.unet = unet
            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

    # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
    result_latents = result_latents.detach().to("cpu")
    TorchDevice.empty_cache()

    name = context.tensors.save(tensor=result_latents)
    return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
Expand Down Expand Up @@ -788,7 +896,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
device=unet.device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
Expand Down
23 changes: 21 additions & 2 deletions invokeai/backend/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

import numpy as np
import torch
Expand All @@ -32,8 +32,27 @@
"""


# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
@contextmanager
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
"""A context manager that patches `unet` with the provided attention processor.

Args:
unet (UNet2DConditionModel): The UNet model to patch.
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
"""
unet_orig_processors = unet.attn_processors
try:
# create separate instance for each attention, to be able modify each attention separately
new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
unet.set_attn_processor(new_attn_processors)

yield None

finally:
unet.set_attn_processor(unet_orig_processors)

@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
Expand Down
122 changes: 122 additions & 0 deletions invokeai/backend/stable_diffusion/denoise_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


@dataclass
class UNetKwargs:
    """Keyword arguments collected for a single UNet call.

    Only `sample`, `timestep` and `encoder_hidden_states` are required; every other field
    defaults to None.
    """

    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    # The residual collections hold a variable number of tensors, hence the variadic
    # `Tuple[X, ...]` form (`Tuple[X]` would mean "exactly one element").
    down_block_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # return_dict: bool = True


@dataclass
class DenoiseInputs:
    """Initial variables passed to denoise. Supposed to be unchanged.

    Variables:
        orig_latents: The latent-space image to denoise.
            Shape: [batch, channels, latent_height, latent_width]
            - If we are inpainting, this is the initial latent image before noise has been added.
            - If we are generating a new image, this should be initialized to zeros.
            - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
        scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
        conditioning_data: Text conditioning data.
        noise: Noise used for two purposes:
            1. Used by the scheduler to noise the initial `latents` before denoising.
            2. Used to noise the `masked_latents` when inpainting.
            `noise` should be None if the `latents` tensor has already been noised.
            Shape: [1 or batch, channels, latent_height, latent_width]
        seed: The seed used to generate the noise for the denoising process.
            HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
            same noise used earlier in the pipeline. This should really be handled in a clearer way.
        timesteps: The timestep schedule for the denoising process.
        init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
            should be populated if you want noise applied *even* if timesteps is empty.
        attention_processor_cls: Class of attention processor that is used.
    """

    orig_latents: torch.Tensor
    scheduler_step_kwargs: Dict[str, Any]
    conditioning_data: TextConditioningData
    noise: Optional[torch.Tensor]
    seed: int
    timesteps: torch.Tensor
    init_timestep: torch.Tensor
    attention_processor_cls: Type[Any]


@dataclass
class DenoiseContext:
    """Context with all variables used during denoising.

    Variables:
        inputs: Initial variables passed to denoise. Supposed to be unchanged.
        scheduler: Scheduler which is used to apply noise predictions.
        unet: UNet model.
        latents: Current state of latent-space image in denoising process.
            None until `pre_denoise_loop` callback.
            Shape: [batch, channels, latent_height, latent_width]
        step_index: Current denoising step index.
            None until `pre_step` callback.
        timestep: Current denoising step timestep.
            None until `pre_step` callback.
        unet_kwargs: Arguments which will be passed to the UNet model.
            Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
        step_output: SchedulerOutput class returned from step function (normally, generated by scheduler).
            Supposed to be used only in `post_step` callback, otherwise can be None.
        latent_model_input: Scaled version of `latents`, which will be passed to unet_kwargs initialization.
            Available in events inside step (between `pre_step` and `post_step`).
            Shape: [batch, channels, latent_height, latent_width]
        conditioning_mode: [TMP] Defines on which conditionings the current unet call will be run.
            Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
            Can be "negative", "positive" or "both"
        negative_noise_pred: [TMP] Noise predictions from negative conditioning.
            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
            Shape: [batch, channels, latent_height, latent_width]
        positive_noise_pred: [TMP] Noise predictions from positive conditioning.
            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
            Shape: [batch, channels, latent_height, latent_width]
        noise_pred: Combined noise prediction from passed conditionings.
            Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
            Shape: [batch, channels, latent_height, latent_width]
        extra: Dictionary for extensions to pass extra info about denoise process to other extensions.
    """

    inputs: DenoiseInputs

    scheduler: SchedulerMixin
    unet: Optional[UNet2DConditionModel] = None

    latents: Optional[torch.Tensor] = None
    step_index: Optional[int] = None
    timestep: Optional[torch.Tensor] = None
    unet_kwargs: Optional[UNetKwargs] = None
    step_output: Optional[SchedulerOutput] = None

    latent_model_input: Optional[torch.Tensor] = None
    conditioning_mode: Optional[str] = None
    negative_noise_pred: Optional[torch.Tensor] = None
    positive_noise_pred: Optional[torch.Tensor] = None
    noise_pred: Optional[torch.Tensor] = None

    # default_factory gives every context its own dict instead of one shared mutable default.
    extra: Dict[str, Any] = field(default_factory=dict)
11 changes: 1 addition & 10 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,12 @@
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel


@dataclass
class PipelineIntermediateState:
    """Plain record describing the pipeline's state at one intermediate step.

    NOTE(review): the field meanings below are inferred from their names only — confirm
    against the code that constructs this object.
    """

    # presumably the index of the current step — TODO confirm
    step: int
    # presumably the scheduler order — TODO confirm
    order: int
    # presumably the total number of steps in the run — TODO confirm
    total_steps: int
    # presumably the scheduler timestep for this step — TODO confirm
    timestep: int
    # latents tensor at this step
    latents: torch.Tensor
    # optional tensor; defaults to None when not provided
    predicted_original: Optional[torch.Tensor] = None


@dataclass
class AddsMaskGuidance:
mask: torch.Tensor
Expand Down
Loading