from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
- from invokeai.backend.model_manager import BaseModelType
+ from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
- from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+ from invokeai.backend.stable_diffusion import PipelineIntermediateState
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
    ControlNetData,
@@ -60,8 +60,12 @@
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
+ from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
+ from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
+ from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
+ from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -498,6 +502,33 @@ def parse_controlnet_field(
                )
            )

+     @staticmethod
+     def parse_t2i_adapter_field(
+         exit_stack: ExitStack,
+         context: InvocationContext,
+         t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+         ext_manager: ExtensionsManager,
+     ) -> None:
+         if t2i_adapters is None:
+             return
+ 
+         # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
+         if isinstance(t2i_adapters, T2IAdapterField):
+             t2i_adapters = [t2i_adapters]
+ 
+         for t2i_adapter_field in t2i_adapters:
+             ext_manager.add_extension(
+                 T2IAdapterExt(
+                     node_context=context,
+                     model_id=t2i_adapter_field.t2i_adapter_model,
+                     image=context.images.get_pil(t2i_adapter_field.image.image_name),
+                     weight=t2i_adapter_field.weight,
+                     begin_step_percent=t2i_adapter_field.begin_step_percent,
+                     end_step_percent=t2i_adapter_field.end_step_percent,
+                     resize_mode=t2i_adapter_field.resize_mode,
+                 )
+             )
+ 
    def prep_ip_adapter_image_prompts(
        self,
        context: InvocationContext,
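The Optional/Union handling at the top of the new parse_t2i_adapter_field mirrors how the node accepts either a single field or a list, as the inline comment notes. As a side reference, the normalization idiom can be sketched in isolation; Item and normalize_to_list below are hypothetical names, not part of this diff:

from typing import Optional, Union

class Item:
    """Hypothetical stand-in for a field type such as T2IAdapterField."""

def normalize_to_list(value: Optional[Union[Item, list[Item]]]) -> list[Item]:
    # None -> nothing to process; a bare Item -> wrap it; a list -> pass through unchanged.
    if value is None:
        return []
    if isinstance(value, Item):
        return [value]
    return value

assert normalize_to_list(None) == []
assert len(normalize_to_list(Item())) == 1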
@@ -707,7 +738,7 @@ def prep_inpaint_mask(
        else:
            masked_latents = torch.where(mask < 0.5, 0.0, latents)

-         return 1 - mask, masked_latents, self.denoise_mask.gradient
+         return mask, masked_latents, self.denoise_mask.gradient

    @staticmethod
    def prepare_noise_and_latents(
@@ -765,10 +796,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
        dtype = TorchDevice.choose_torch_dtype()

        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
-         latents = latents.to(device=device, dtype=dtype)
-         if noise is not None:
-             noise = noise.to(device=device, dtype=dtype)
- 
        _, _, latent_height, latent_width = latents.shape

        conditioning_data = self.get_conditioning_data(
@@ -801,21 +828,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
            denoising_end=self.denoising_end,
        )

-         denoise_ctx = DenoiseContext(
-             inputs=DenoiseInputs(
-                 orig_latents=latents,
-                 timesteps=timesteps,
-                 init_timestep=init_timestep,
-                 noise=noise,
-                 seed=seed,
-                 scheduler_step_kwargs=scheduler_step_kwargs,
-                 conditioning_data=conditioning_data,
-                 attention_processor_cls=CustomAttnProcessor2_0,
-             ),
-             unet=None,
-             scheduler=scheduler,
-         )
- 
        # get the unet's config so that we can pass the base to sd_step_callback()
        unet_config = context.models.get_config(self.unet.unet.key)

@@ -833,13 +845,48 @@ def step_callback(state: PipelineIntermediateState) -> None:
        if self.unet.freeu_config:
            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))

+         ### seamless
+         if self.unet.seamless_axes:
+             ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
+ 
+         ### inpaint
+         mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
+         # NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we
+         # use the ModelVariantType config. During testing, there was a report of a user with models that had an
+         # incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
+         # prevalent, we will have to revisit how we initialize the inpainting extensions.
+         if unet_config.variant == ModelVariantType.Inpaint:
+             ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
+         elif mask is not None:
+             ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
+ 
+         # Initialize context for modular denoise
+         latents = latents.to(device=device, dtype=dtype)
+         if noise is not None:
+             noise = noise.to(device=device, dtype=dtype)
+         denoise_ctx = DenoiseContext(
+             inputs=DenoiseInputs(
+                 orig_latents=latents,
+                 timesteps=timesteps,
+                 init_timestep=init_timestep,
+                 noise=noise,
+                 seed=seed,
+                 scheduler_step_kwargs=scheduler_step_kwargs,
+                 conditioning_data=conditioning_data,
+                 attention_processor_cls=CustomAttnProcessor2_0,
+             ),
+             unet=None,
+             scheduler=scheduler,
+         )
+ 
        # context for loading additional models
        with ExitStack() as exit_stack:
            # later should be smth like:
            # for extension_field in self.extensions:
            #     ext = extension_field.to_extension(exit_stack, context, ext_manager)
            #     ext_manager.add_extension(ext)
            self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+             self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)

            # ext: t2i/ip adapter
            ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
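The NOTE above explains why the inpainting extension is now chosen from the model config rather than from UNet weight shapes. A condensed, hedged restatement of that dispatch as a standalone helper; pick_inpaint_extension is a made-up name, the constructors and ModelVariantType come straight from the diff, and the channel rationale is my reading rather than something the code states:

from invokeai.backend.model_manager import ModelVariantType
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt

def pick_inpaint_extension(variant, mask, masked_latents, is_gradient_mask):
    # Inpaint-variant UNets are wired for extra mask/masked-latents input channels,
    # so they always get InpaintModelExt, even when no mask was supplied.
    if variant == ModelVariantType.Inpaint:
        return InpaintModelExt(mask, masked_latents, is_gradient_mask)
    # Ordinary UNets only need mask blending, and only when a mask is present.
    if mask is not None:
        return InpaintExt(mask, is_gradient_mask)
    return None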
@@ -871,6 +918,10 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+         # At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
+         # We invert the mask here for compatibility with the old backend implementation.
+         if mask is not None:
+             mask = 1 - mask

        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
        # below. Investigate whether this is appropriate.
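The two mask-handling changes in this diff fit together: prep_inpaint_mask now returns the mask with 1.0 marking pixels to inpaint (the old `return 1 - mask, ...` was dropped), and the legacy _old_invoke path re-inverts it for the old backend. A tiny illustration of the two orientations as I read the added comment, not taken from the diff itself:

import torch

# Convention now returned by prep_inpaint_mask(): 1.0 = inpaint, 0.0 = leave unchanged.
mask = torch.tensor([0.0, 0.25, 0.75, 1.0])

# Convention expected by the old backend (what `1 - mask` restores): 1.0 = keep, 0.0 = inpaint.
legacy_mask = 1 - mask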
@@ -915,7 +966,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            ExitStack() as exit_stack,
            unet_info.model_on_device() as (model_state_dict, unet),
            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-             set_seamless(unet, self.unet.seamless_axes),  # FIXME
+             SeamlessExt.static_patch_model(unet, self.unet.seamless_axes),  # FIXME
            # Apply the LoRA after unet has been moved to its target device for faster patching.
            ModelPatcher.apply_lora_unet(
                unet,
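SeamlessExt.static_patch_model(...) replaces set_seamless(...) as a context manager in the legacy path above. A minimal usage sketch, under the assumption that, like the old helper, it patches the UNet only for the duration of the with-block and that seamless axes are given as a list such as ["x", "y"]:

from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt

def run_seamless_denoise(unet, seamless_axes=("x", "y")):
    # Patch the UNet's convolutions for tiling only while inside the block;
    # the patch is expected to be undone on exit, mirroring the removed set_seamless helper.
    with SeamlessExt.static_patch_model(unet, list(seamless_axes)):
        ...  # run the denoising loop with seamless tiling active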