
Commit 693a3ea

Merge branch 'main' into stalker-modular_inpaint-2

2 parents: 84d0288 + 171a4e6

172 files changed: +5710 −2809 lines

docker/Dockerfile

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 FROM node:20-slim AS web-builder
 ENV PNPM_HOME="/pnpm"
 ENV PATH="$PNPM_HOME:$PATH"
+RUN corepack use [email protected]
 RUN corepack enable
 
 WORKDIR /build

invokeai/app/api/routers/model_manager.py

Lines changed: 14 additions & 15 deletions
@@ -6,7 +6,7 @@
 import traceback
 from copy import deepcopy
 from tempfile import TemporaryDirectory
-from typing import Any, Dict, List, Optional, Type
+from typing import List, Optional, Type
 
 from fastapi import Body, Path, Query, Response, UploadFile
 from fastapi.responses import FileResponse, HTMLResponse
@@ -430,13 +430,11 @@ async def delete_model_image(
 async def install_model(
     source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
     inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
-    # TODO(MM2): Can we type this?
-    config: Optional[Dict[str, Any]] = Body(
-        description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
-        default=None,
+    access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
+    config: ModelRecordChanges = Body(
+        description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
         example={"name": "string", "description": "string"},
     ),
-    access_token: Optional[str] = None,
 ) -> ModelInstallJob:
     """Install a model using a string identifier.
 
@@ -451,8 +449,9 @@ async def install_model(
     - model/name:fp16:path/to/model.safetensors
     - model/name::path/to/model.safetensors
 
-    `config` is an optional dict containing model configuration values that will override
-    the ones that are probed automatically.
+    `config` is a ModelRecordChanges object. Fields in this object will override
+    the ones that are probed automatically. Pass an empty object to accept
+    all the defaults.
 
     `access_token` is an optional access token for use with Urls that require
     authentication.
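
Note: after this change a client sends `config` as the JSON body (an empty object accepts all probed defaults) and passes `access_token` in the query string instead of the body. A minimal sketch of a client call against the new signature; the host, port, and route prefix are assumptions about a default local install, not taken from this commit:

    # Hedged sketch of a client call for the new install_model signature.
    import requests

    resp = requests.post(
        "http://127.0.0.1:9090/api/v2/models/install",  # assumed host/port/prefix
        params={
            "source": "author/model-name",  # repo_id, local path, or URL
            "inplace": False,
            # access_token is now a query parameter, e.g. "access_token": "hf_...",
        },
        json={},  # empty ModelRecordChanges body: accept all auto-probed values
        timeout=30,
    )
    resp.raise_for_status()
    job = resp.json()  # serialized ModelInstallJob
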
@@ -737,7 +736,7 @@ async def convert_model(
     # write the converted file to the convert path
     raw_model = converted_model.model
     assert hasattr(raw_model, "save_pretrained")
-    raw_model.save_pretrained(convert_path)
+    raw_model.save_pretrained(convert_path)  # type: ignore
     assert convert_path.exists()
 
     # temporarily rename the original safetensors file so that there is no naming conflict
@@ -750,12 +749,12 @@ async def convert_model(
     try:
         new_key = installer.install_path(
             convert_path,
-            config={
-                "name": original_name,
-                "description": model_config.description,
-                "hash": model_config.hash,
-                "source": model_config.source,
-            },
+            config=ModelRecordChanges(
+                name=original_name,
+                description=model_config.description,
+                hash=model_config.hash,
+                source=model_config.source,
+            ),
         )
     except Exception as e:
         logger.error(str(e))
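
Note: both hunks above replace untyped dicts with `ModelRecordChanges`. A rough sketch of the shape that object takes, inferred only from the fields used in this diff — the real class lives in the model-records service and carries more fields; all are optional overrides for auto-probed values:

    # Sketch inferred from the diff above; NOT the actual class definition.
    from typing import Optional
    from pydantic import BaseModel

    class ModelRecordChanges(BaseModel):
        name: Optional[str] = None
        description: Optional[str] = None
        source: Optional[str] = None
        hash: Optional[str] = None
        prediction_type: Optional[str] = None  # an enum in the real code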

invokeai/app/invocations/denoise_latents.py

Lines changed: 103 additions & 19 deletions
@@ -39,7 +39,7 @@
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
@@ -58,9 +58,14 @@
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
+from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
 from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
+from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
+from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -465,6 +470,65 @@ def prep_control_data(
 
         return controlnet_data
 
+    @staticmethod
+    def parse_controlnet_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        control_input: ControlField | list[ControlField] | None,
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        # Normalize control_input to a list.
+        control_list: list[ControlField]
+        if isinstance(control_input, ControlField):
+            control_list = [control_input]
+        elif isinstance(control_input, list):
+            control_list = control_input
+        elif control_input is None:
+            control_list = []
+        else:
+            raise ValueError(f"Unexpected control_input type: {type(control_input)}")
+
+        for control_info in control_list:
+            model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            ext_manager.add_extension(
+                ControlNetExt(
+                    model=model,
+                    image=context.images.get_pil(control_info.image.image_name),
+                    weight=control_info.control_weight,
+                    begin_step_percent=control_info.begin_step_percent,
+                    end_step_percent=control_info.end_step_percent,
+                    control_mode=control_info.control_mode,
+                    resize_mode=control_info.resize_mode,
+                )
+            )
+
+    @staticmethod
+    def parse_t2i_adapter_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        if t2i_adapters is None:
+            return
+
+        # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
+        if isinstance(t2i_adapters, T2IAdapterField):
+            t2i_adapters = [t2i_adapters]
+
+        for t2i_adapter_field in t2i_adapters:
+            ext_manager.add_extension(
+                T2IAdapterExt(
+                    node_context=context,
+                    model_id=t2i_adapter_field.t2i_adapter_model,
+                    image=context.images.get_pil(t2i_adapter_field.image.image_name),
+                    weight=t2i_adapter_field.weight,
+                    begin_step_percent=t2i_adapter_field.begin_step_percent,
+                    end_step_percent=t2i_adapter_field.end_step_percent,
+                    resize_mode=t2i_adapter_field.resize_mode,
+                )
+            )
+
 def prep_ip_adapter_image_prompts(
     self,
     context: InvocationContext,
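
Note: both helpers funnel into `ext_manager.add_extension(...)`. For readers new to the modular backend, here is a toy sketch of the extension contract these classes follow. The base-class and decorator names are assumptions for illustration, not verbatim from this commit; only `ExtensionCallbackType.SETUP` and `DenoiseInputs.orig_latents` actually appear in this diff:

    # Toy extension sketch. ExtensionBase and `callback` are assumed names.
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
    from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

    class ScaleLatentsExt(ExtensionBase):
        """Toy example: scales the initial latents once, before the denoise loop."""

        def __init__(self, scale: float):
            super().__init__()
            self._scale = scale

        @callback(ExtensionCallbackType.SETUP)
        def scale_latents(self, ctx: DenoiseContext) -> None:
            # ctx.inputs.orig_latents is confirmed by the DenoiseInputs usage above.
            ctx.inputs.orig_latents = ctx.inputs.orig_latents * self._scale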
@@ -773,6 +837,18 @@ def step_callback(state: PipelineIntermediateState) -> None:
 
         ext_manager.add_extension(PreviewExt(step_callback))
 
+        ### cfg rescale
+        if self.cfg_rescale_multiplier > 0:
+            ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
+
+        ### freeu
+        if self.unet.freeu_config:
+            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
+
+        ### seamless
+        if self.unet.seamless_axes:
+            ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
+
         ### inpaint
         mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
         # NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
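
Note on the cfg rescale hook: the rescale trick (Lin et al., "Common Diffusion Noise Schedules and Sample Steps Are Flawed", 2023) re-standardizes the guided noise prediction toward the conditional prediction's per-sample std, then blends by the multiplier; a zero multiplier leaves plain CFG untouched, which is why registration is gated on `> 0`. A standalone sketch of that math, not the `RescaleCFGExt` source:

    # Standalone sketch of CFG rescale (Lin et al., 2023); not RescaleCFGExt itself.
    import torch

    def rescale_cfg(cond_pred: torch.Tensor, guided_pred: torch.Tensor, multiplier: float) -> torch.Tensor:
        # Match the guided prediction's per-sample std to the conditional one.
        std_cond = cond_pred.std(dim=[1, 2, 3], keepdim=True)
        std_guided = guided_pred.std(dim=[1, 2, 3], keepdim=True)
        rescaled = guided_pred * (std_cond / std_guided)
        # multiplier = 0.0 -> plain CFG; 1.0 -> fully rescaled prediction.
        return multiplier * rescaled + (1.0 - multiplier) * guided_pred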
@@ -788,7 +864,6 @@ def step_callback(state: PipelineIntermediateState) -> None:
         latents = latents.to(device=device, dtype=dtype)
         if noise is not None:
             noise = noise.to(device=device, dtype=dtype)
-
         denoise_ctx = DenoiseContext(
             inputs=DenoiseInputs(
                 orig_latents=latents,
@@ -804,22 +879,31 @@ def step_callback(state: PipelineIntermediateState) -> None:
             scheduler=scheduler,
         )
 
-        # ext: t2i/ip adapter
-        ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
-
-        unet_info = context.models.load(self.unet.unet)
-        assert isinstance(unet_info.model, UNet2DConditionModel)
-        with (
-            unet_info.model_on_device() as (model_state_dict, unet),
-            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
-            # ext: controlnet
-            ext_manager.patch_extensions(unet),
-            # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(model_state_dict, unet),
-        ):
-            sd_backend = StableDiffusionBackend(unet, scheduler)
-            denoise_ctx.unet = unet
-            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+        # context for loading additional models
+        with ExitStack() as exit_stack:
+            # later should be smth like:
+            # for extension_field in self.extensions:
+            #     ext = extension_field.to_extension(exit_stack, context, ext_manager)
+            #     ext_manager.add_extension(ext)
+            self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+            self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
+
+            # ext: t2i/ip adapter
+            ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (cached_weights, unet),
+                ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+                # ext: controlnet
+                ext_manager.patch_extensions(denoise_ctx),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(unet, cached_weights),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.detach().to("cpu")
@@ -882,7 +966,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             ExitStack() as exit_stack,
             unet_info.model_on_device() as (model_state_dict, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
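
Note on the `(unet, cached_weights)` signature introduced above: weight-patching extensions (LoRA and friends) generally mutate tensors in place while the model sits on its execution device, then restore the originals from a cached copy on exit rather than subtracting the patch back out. A generic sketch of that pattern under those assumptions; this is not `ExtensionsManager.patch_unet` itself:

    # Generic in-place patch/restore sketch suggested by the signature change;
    # a hedged illustration, not InvokeAI's implementation.
    from contextlib import contextmanager
    from typing import Dict, Iterator

    import torch

    @contextmanager
    def patched_weights(
        model: torch.nn.Module,
        patches: Dict[str, torch.Tensor],
        cached_weights: Dict[str, torch.Tensor],
    ) -> Iterator[None]:
        # state_dict() returns references to the live tensors, so the in-place
        # edits below really modify the model.
        state = model.state_dict()
        touched: list = []
        try:
            with torch.no_grad():
                for key, delta in patches.items():
                    state[key] += delta.to(device=state[key].device, dtype=state[key].dtype)
                    touched.append(key)
            yield
        finally:
            # Restore from cached originals instead of subtracting the delta,
            # so repeated patching cannot accumulate floating-point drift.
            with torch.no_grad():
                for key in touched:
                    state[key].copy_(cached_weights[key])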

invokeai/app/invocations/latents_to_image.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
 
@@ -59,7 +59,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
 
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+        with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(vae.device)
             if self.fp32:
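
Note: `SeamlessExt.static_patch_model` replaces the old `set_seamless` context manager here without changing behavior. Seamless tiling is conventionally implemented by making every conv layer pad circularly along the selected axes, so generated textures wrap at the borders. A self-contained sketch of that standard trick, not InvokeAI's exact code; it assumes symmetric integer conv padding:

    # Sketch of seamless tiling via per-axis circular padding on Conv2d layers.
    from contextlib import contextmanager
    from typing import Iterator, List

    import torch
    import torch.nn.functional as F

    @contextmanager
    def seamless(model: torch.nn.Module, seamless_axes: List[str]) -> Iterator[None]:
        x_mode = "circular" if "x" in seamless_axes else "constant"
        y_mode = "circular" if "y" in seamless_axes else "constant"
        originals = []

        def make_forward(conv: torch.nn.Conv2d):
            pad_h, pad_w = conv.padding  # assumes symmetric int padding
            def _conv_forward(input: torch.Tensor, weight: torch.Tensor, bias):
                # Pad width (x) and height (y) separately so each axis can
                # wrap or zero-pad independently, then convolve unpadded.
                input = F.pad(input, (pad_w, pad_w, 0, 0), mode=x_mode)
                input = F.pad(input, (0, 0, pad_h, pad_h), mode=y_mode)
                return F.conv2d(input, weight, bias, conv.stride, (0, 0), conv.dilation, conv.groups)
            return _conv_forward

        try:
            for m in model.modules():
                if isinstance(m, torch.nn.Conv2d):
                    originals.append((m, m._conv_forward))
                    m._conv_forward = make_forward(m)
            yield
        finally:
            for conv, fwd in originals:
                conv._conv_forward = fwd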
