@@ -76,7 +76,7 @@ class AttentionSharingUnit(torch.nn.Module):
     # call.
     transformer_options: dict = {}
 
-    def __init__(self, module, frames=2, use_control=True, rank=256):
+    def __init__(self, module, frames=2, control_signals=None, rank=256):
         super().__init__()
 
         self.heads = module.heads
@@ -142,9 +142,9 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
             in_features=hidden_size, out_features=hidden_size
         )
 
+        self.control_signals = control_signals
         self.control_convs = None
-
-        if use_control:
+        if control_signals is not None:
             self.control_convs = [
                 torch.nn.Sequential(
                     torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
@@ -155,7 +155,6 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
             ]
             self.control_convs = torch.nn.ModuleList(self.control_convs)
 
-        self.control_signals = None
 
     def forward(self, h, context=None, value=None):
         transformer_options = self.transformer_options
@@ -325,36 +324,29 @@ def __init__(self, layer_list):
 
 
 class AttentionSharingPatcher(torch.nn.Module):
-    def __init__(self, unet, frames=2, use_control=True, rank=256):
+    def __init__(self, unet, frames=2, control_img=None, rank=256):
         super().__init__()
-        model_management.unload_model_clones(unet)
+        control_signals = (
+            AdditionalAttentionCondsEncoder()(control_img.cpu().float() * 2.0 - 1.0)
+            if control_img is not None
+            else None
+        )
 
         units = []
         for i in range(32):
             real_key = module_mapping_sd15[i]
             attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
             u = AttentionSharingUnit(
-                attn_module, frames=frames, use_control=use_control, rank=rank
+                attn_module,
+                frames=frames,
+                control_signals=control_signals,
+                rank=rank,
             )
             units.append(u)
             unet.add_object_patch("diffusion_model." + real_key, u)
-
         self.hookers = HookerLayers(units)
 
-        if use_control:
-            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
-        else:
-            self.kwargs_encoder = None
-
         self.dtype = torch.float32
         if model_management.should_use_fp16(model_management.get_torch_device()):
             self.dtype = torch.float16
             self.hookers.half()
-        return
-
-    def set_control(self, img):
-        img = img.cpu().float() * 2.0 - 1.0
-        signals = self.kwargs_encoder(img)
-        for m in self.hookers.layers:
-            m.control_signals = signals
-        return
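
A minimal usage sketch of the resulting API change, assuming a ComfyUI-style `unet` ModelPatcher and a prepared `control_img` tensor (both names are illustrative; the surrounding model-loading code is omitted):

```python
# Before this change, control was attached after construction via a
# separate call (hypothetical call site):
#   patcher = AttentionSharingPatcher(unet, frames=2, use_control=True, rank=256)
#   patcher.set_control(control_img)

# After this change, the control image is encoded once in __init__ and the
# resulting signals are passed into every AttentionSharingUnit, so no
# set_control() step remains:
patcher = AttentionSharingPatcher(unet, frames=2, control_img=control_img, rank=256)
```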