8
8
from PIL .Image import Image
9
9
10
10
from invokeai .app .invocations .constants import LATENT_SCALE_FACTOR
11
- from invokeai .app .util .controlnet_utils import prepare_control_image
11
+ from invokeai .app .util .controlnet_utils import CONTROLNET_MODE_VALUES , CONTROLNET_RESIZE_VALUES , prepare_control_image
12
12
from invokeai .backend .stable_diffusion .denoise_context import UNetKwargs
13
13
from invokeai .backend .stable_diffusion .diffusion .conditioning_data import ConditioningMode
14
14
from invokeai .backend .stable_diffusion .extension_callback_type import ExtensionCallbackType
@@ -27,8 +27,8 @@ def __init__(
27
27
weight : Union [float , List [float ]],
28
28
begin_step_percent : float ,
29
29
end_step_percent : float ,
30
- control_mode : str ,
31
- resize_mode : str ,
30
+ control_mode : CONTROLNET_MODE_VALUES ,
31
+ resize_mode : CONTROLNET_RESIZE_VALUES ,
32
32
):
33
33
super ().__init__ ()
34
34
self ._model = model
@@ -43,8 +43,8 @@ def __init__(
43
43
44
44
@contextmanager
45
45
def patch_extension (self , ctx : DenoiseContext ):
46
+ original_processors = self ._model .attn_processors
46
47
try :
47
- original_processors = self ._model .attn_processors
48
48
self ._model .set_attn_processor (ctx .inputs .attention_processor_cls ())
49
49
50
50
yield None
@@ -62,8 +62,6 @@ def resize_image(self, ctx: DenoiseContext):
62
62
do_classifier_free_guidance = False ,
63
63
width = image_width ,
64
64
height = image_height ,
65
- # batch_size=batch_size * num_images_per_prompt,
66
- # num_images_per_prompt=num_images_per_prompt,
67
65
device = ctx .latents .device ,
68
66
dtype = ctx .latents .dtype ,
69
67
control_mode = self ._control_mode ,
@@ -125,7 +123,7 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con
125
123
cn_unet_kwargs = UNetKwargs (
126
124
sample = model_input ,
127
125
timestep = ctx .timestep ,
128
- encoder_hidden_states = None , # set later by conditoning
126
+ encoder_hidden_states = None , # set later by conditioning
129
127
cross_attention_kwargs = dict ( # noqa: C408
130
128
percent_through = ctx .step_index / total_steps ,
131
129
),
@@ -139,9 +137,14 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con
139
137
weight = weight [ctx .step_index ]
140
138
141
139
tmp_kwargs = vars (cn_unet_kwargs )
142
- tmp_kwargs .pop ("down_block_additional_residuals" , None )
143
- tmp_kwargs .pop ("mid_block_additional_residual" , None )
144
- tmp_kwargs .pop ("down_intrablock_additional_residuals" , None )
140
+
141
+ # Remove kwargs not related to ControlNet unet
142
+ # ControlNet guidance fields
143
+ del tmp_kwargs ["down_block_additional_residuals" ]
144
+ del tmp_kwargs ["mid_block_additional_residual" ]
145
+
146
+ # T2i Adapter guidance fields
147
+ del tmp_kwargs ["down_intrablock_additional_residuals" ]
145
148
146
149
# controlnet(s) inference
147
150
down_samples , mid_sample = self ._model (
0 commit comments