@@ -39,7 +39,7 @@
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, ModelVariantType
 from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
 from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
@@ -58,9 +58,14 @@
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
+from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
 from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
+from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
+from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
+from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -465,6 +470,65 @@ def prep_control_data(
 
         return controlnet_data
 
+    @staticmethod
+    def parse_controlnet_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        control_input: ControlField | list[ControlField] | None,
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        # Normalize control_input to a list.
+        control_list: list[ControlField]
+        if isinstance(control_input, ControlField):
+            control_list = [control_input]
+        elif isinstance(control_input, list):
+            control_list = control_input
+        elif control_input is None:
+            control_list = []
+        else:
+            raise ValueError(f"Unexpected control_input type: {type(control_input)}")
+
+        for control_info in control_list:
+            model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            ext_manager.add_extension(
+                ControlNetExt(
+                    model=model,
+                    image=context.images.get_pil(control_info.image.image_name),
+                    weight=control_info.control_weight,
+                    begin_step_percent=control_info.begin_step_percent,
+                    end_step_percent=control_info.end_step_percent,
+                    control_mode=control_info.control_mode,
+                    resize_mode=control_info.resize_mode,
+                )
+            )
+
+    @staticmethod
+    def parse_t2i_adapter_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        if t2i_adapters is None:
+            return
+
+        # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
+        if isinstance(t2i_adapters, T2IAdapterField):
+            t2i_adapters = [t2i_adapters]
+
+        for t2i_adapter_field in t2i_adapters:
+            ext_manager.add_extension(
+                T2IAdapterExt(
+                    node_context=context,
+                    model_id=t2i_adapter_field.t2i_adapter_model,
+                    image=context.images.get_pil(t2i_adapter_field.image.image_name),
+                    weight=t2i_adapter_field.weight,
+                    begin_step_percent=t2i_adapter_field.begin_step_percent,
+                    end_step_percent=t2i_adapter_field.end_step_percent,
+                    resize_mode=t2i_adapter_field.resize_mode,
+                )
+            )
+
     def prep_ip_adapter_image_prompts(
         self,
         context: InvocationContext,
@@ -773,6 +837,18 @@ def step_callback(state: PipelineIntermediateState) -> None:
 
         ext_manager.add_extension(PreviewExt(step_callback))
 
+        ### cfg rescale
+        if self.cfg_rescale_multiplier > 0:
+            ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
+
+        ### freeu
+        if self.unet.freeu_config:
+            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
+
+        ### seamless
+        if self.unet.seamless_axes:
+            ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
+
         ### inpaint
         mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
         # NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we
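For orientation: the `cfg_rescale_multiplier` hook registered above typically applies the guidance-rescale trick from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Lin et al., 2023). The sketch below shows that formula in isolation; it is an illustration with made-up tensor names, not the actual body of `RescaleCFGExt`.

```python
import torch

def rescale_noise_pred(noise_pred: torch.Tensor, noise_pred_cond: torch.Tensor, multiplier: float) -> torch.Tensor:
    # Match the std of the CFG-guided prediction to the conditional prediction, then blend
    # by `multiplier` (0 = plain CFG, 1 = fully rescaled).
    dims = list(range(1, noise_pred.ndim))
    std_cond = noise_pred_cond.std(dim=dims, keepdim=True)
    std_cfg = noise_pred.std(dim=dims, keepdim=True)
    rescaled = noise_pred * (std_cond / std_cfg)
    return multiplier * rescaled + (1.0 - multiplier) * noise_pred
```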
@@ -788,7 +864,6 @@ def step_callback(state: PipelineIntermediateState) -> None:
         latents = latents.to(device=device, dtype=dtype)
         if noise is not None:
             noise = noise.to(device=device, dtype=dtype)
-
         denoise_ctx = DenoiseContext(
             inputs=DenoiseInputs(
                 orig_latents=latents,
@@ -804,22 +879,31 @@ def step_callback(state: PipelineIntermediateState) -> None:
             scheduler=scheduler,
         )
 
-        # ext: t2i/ip adapter
-        ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
-
-        unet_info = context.models.load(self.unet.unet)
-        assert isinstance(unet_info.model, UNet2DConditionModel)
-        with (
-            unet_info.model_on_device() as (model_state_dict, unet),
-            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
-            # ext: controlnet
-            ext_manager.patch_extensions(unet),
-            # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(model_state_dict, unet),
-        ):
-            sd_backend = StableDiffusionBackend(unet, scheduler)
-            denoise_ctx.unet = unet
-            result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+        # Context for loading additional models.
+        with ExitStack() as exit_stack:
+            # Later this should become something like:
+            # for extension_field in self.extensions:
+            #     ext = extension_field.to_extension(exit_stack, context, ext_manager)
+            #     ext_manager.add_extension(ext)
+            self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+            self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
+
+            # ext: t2i/ip adapter
+            ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (cached_weights, unet),
+                ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
+                # ext: controlnet
+                ext_manager.patch_extensions(denoise_ctx),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(unet, cached_weights),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.detach().to("cpu")
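A minimal illustration (not InvokeAI code) of why the new code wraps model loading in an `ExitStack`: every context entered via `enter_context()` stays open until the `with` block exits, so ControlNet and T2I-Adapter weights loaded by the `parse_*` helpers remain resident for the whole denoise loop and are released together afterwards. The `fake_model` context manager below is a stand-in invented for the example.

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def fake_model(name: str):
    print(f"loaded {name}")
    try:
        yield f"{name}-weights"
    finally:
        print(f"released {name}")

with ExitStack() as exit_stack:
    # Each model stays loaded until the ExitStack closes, mirroring how the denoise
    # invocation keeps extension models alive for the duration of sampling.
    models = [exit_stack.enter_context(fake_model(n)) for n in ("controlnet", "t2i_adapter")]
    print("denoising with", models)
# Both models are released only here, after the block ends.
```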
@@ -882,7 +966,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             ExitStack() as exit_stack,
             unet_info.model_on_device() as (model_state_dict, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
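For background only: seamless tiling is generally achieved by switching the model's convolutions to circular padding on the wrap-around axes, which is the idea `SeamlessExt` encapsulates. The sketch below shows that general technique under the assumption of wrapping both axes; the extension's real patching logic (including single-axis handling) is more involved and is not reproduced here.

```python
import torch.nn as nn

def apply_seamless_both_axes(unet: nn.Module) -> None:
    # Circular padding makes each conv wrap at the borders, so decoded images tile
    # seamlessly in x and y. Wrapping only one axis needs asymmetric padding and is
    # omitted from this sketch.
    for layer in unet.modules():
        if isinstance(layer, nn.Conv2d):
            layer.padding_mode = "circular"
```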