
Commit ba0e74e

Merge PR #198 from Kosinkadink/rework-modelpatcher

Rework ModelPatcher for upcoming ComfyUI update

2 parents 172543b + 3ad082b

19 files changed: +1535 −1620 lines

__init__.py

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,11 @@
 from .adv_control.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
 from .adv_control import documentation
+from .adv_control.dinklink import init_dinklink
+from .adv_control.sampling import prepare_dinklink_acn_wrapper
 
 WEB_DIRECTORY = "./web"
 __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"]
 documentation.format_descriptions(NODE_CLASS_MAPPINGS)
+
+init_dinklink()
+prepare_dinklink_acn_wrapper()
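The two new imports wire up DinkLink at package load time. The dinklink module itself is not part of this diff; conceptually, DinkLink is a registry shared between Kosinkadink extensions so they can hand each other callables without hard import dependencies. A hypothetical sketch of that import-time registration pattern (everything below except the two function names is illustrative, not the actual implementation):

    # hypothetical sketch of an import-time shared-registry pattern (not the real dinklink code)
    import comfy.model_patcher

    _DINKLINK_ATTR = "__DINKLINK"  # illustrative attribute name

    def _acn_sampling_wrapper(executor, *args, **kwargs):
        # placeholder for the Advanced-ControlNet sampling wrapper published to other extensions
        return executor(*args, **kwargs)

    def init_dinklink():
        # create the shared registry once, somewhere every extension can reach
        if not hasattr(comfy.model_patcher, _DINKLINK_ATTR):
            setattr(comfy.model_patcher, _DINKLINK_ATTR, {})

    def prepare_dinklink_acn_wrapper():
        # publish this extension's wrapper under a well-known key
        registry: dict = getattr(comfy.model_patcher, _DINKLINK_ATTR)
        registry["ACN_sampling_wrapper"] = _acn_sampling_wrapper  # illustrative key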

adv_control/control.py

Lines changed: 149 additions & 85 deletions
Large diffs are not rendered by default.
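Although control.py's diff is collapsed, the other rendered files show its central interface change: get_control and get_control_advanced now receive the sampler's transformer_options dict. A minimal sketch of the reworked call shape, inferred from the call sites in the files below (not the actual control.py code):

    from torch import Tensor

    class ControlBaseSketch:
        """Sketch of the reworked interface, inferred from call sites in the rendered diffs."""
        previous_controlnet = None

        def get_control(self, x_noisy: Tensor, t, cond, batched_number: int, transformer_options: dict):
            # transformer_options now rides along so controls can register per-step patches
            return self.get_control_advanced(x_noisy, t, cond, batched_number, transformer_options)

        def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int, transformer_options: dict):
            control_prev = None
            if self.previous_controlnet is not None:
                # threaded down the chain of attached controlnets
                control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
            return control_prev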

adv_control/control_ctrlora.py

Lines changed: 228 additions & 0 deletions

@@ -0,0 +1,228 @@
+import torch
+from torch import Tensor
+
+from comfy.cldm.cldm import ControlNet as ControlNetCLDM
+import comfy.model_detection
+import comfy.model_management
+import comfy.ops
+import comfy.utils
+
+from comfy.ldm.modules.diffusionmodules.util import (
+    zero_module,
+    timestep_embedding,
+)
+
+from .control import ControlNetAdvanced
+from .utils import TimestepKeyframeGroup
+from .logger import logger
+
+
+class ControlNetCtrLoRA(ControlNetCLDM):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # delete input hint block
+        del self.input_hint_block
+
+    def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+        emb = self.time_embed(t_emb)
+
+        out_output = []
+        out_middle = []
+
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+
+        h = hint.to(dtype=x.dtype)
+        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+            h = module(h, emb, context)
+            out_output.append(zero_conv(h, emb, context))
+
+        h = self.middle_block(h, emb, context)
+        out_middle.append(self.middle_block_out(h, emb, context))
+
+        return {"middle": out_middle, "output": out_output}
+
+
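Worth noting before the next class: this forward differs from the stock comfy.cldm.cldm.ControlNet forward, which compresses an image-space hint through input_hint_block and injects it into the noisy latent x. Here input_hint_block is deleted and the hint (already a VAE-encoded latent, per CtrLoRAAdvanced below) is run through the input blocks itself, while x is unused. For comparison, a rough paraphrase of the stock hint handling (see comfy/cldm/cldm.py for the real code):

    from torch import Tensor
    import torch.nn as nn

    def stock_cldm_hint_mixing(x: Tensor, hint: Tensor, emb, context,
                               input_hint_block: nn.Module, input_blocks, zero_convs):
        # rough paraphrase of stock ControlNetCLDM.forward hint handling, for comparison only
        guided_hint = input_hint_block(hint, emb, context)  # compress image-space hint
        outs = []
        h = x                                               # start from the noisy latent
        for module, zero_conv in zip(input_blocks, zero_convs):
            h = module(h, emb, context)
            if guided_hint is not None:
                h += guided_hint                            # inject the hint once, after the first block
                guided_hint = None
            outs.append(zero_conv(h, emb, context))
        return outs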
+class CtrLoRAAdvanced(ControlNetAdvanced):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.require_vae = True
+        self.mult_by_ratio_when_vae = False
+
+    def pre_run_advanced(self, model, percent_to_timestep_function):
+        super().pre_run_advanced(model, percent_to_timestep_function)
+        self.latent_format = model.latent_format  # LatentFormat object, used to process_in latent cond hint
+
+    def cleanup_advanced(self):
+        super().cleanup_advanced()
+        if self.latent_format is not None:
+            del self.latent_format
+            self.latent_format = None
+
+    def copy(self):
+        c = CtrLoRAAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        self.copy_to_advanced(c)
+        return c
+
+
+def load_ctrlora(base_path: str, lora_path: str,
+                 base_data: dict[str, Tensor]=None, lora_data: dict[str, Tensor]=None,
+                 timestep_keyframe: TimestepKeyframeGroup=None, model=None, model_options={}):
+    if base_data is None:
+        base_data = comfy.utils.load_torch_file(base_path, safe_load=True)
+    controlnet_data = base_data
+
+    # first, check that base_data contains keys with lora_layer
+    contains_lora_layers = False
+    for key in base_data:
+        if "lora_layer" in key:
+            contains_lora_layers = True
+    if not contains_lora_layers:
+        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; does not contain any lora_layer keys.")
+
+    controlnet_config = None
+    supported_inference_dtypes = None
+
+    pth_key = 'control_model.zero_convs.0.0.weight'
+    pth = False
+    key = 'zero_convs.0.0.weight'
+    if pth_key in controlnet_data:
+        pth = True
+        key = pth_key
+        prefix = "control_model."
+    elif key in controlnet_data:
+        prefix = ""
+    else:
+        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; zero_convs weights not found.")
+        # unreachable fallback below (appears carried over from comfy.controlnet.load_controlnet);
+        # it would also require load_t2i_adapter and logging to be imported
+        #net = load_t2i_adapter(controlnet_data, model_options=model_options)
+        #if net is None:
+        #    logging.error("error could not detect control model type.")
+        #return net
+
+    if controlnet_config is None:
+        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
+        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
+        controlnet_config = model_config.unet_config
+
+    unet_dtype = model_options.get("dtype", None)
+    if unet_dtype is None:
+        weight_dtype = comfy.utils.weight_dtype(controlnet_data)
+
+        if supported_inference_dtypes is None:
+            supported_inference_dtypes = [comfy.model_management.unet_dtype()]
+
+        if weight_dtype is not None:
+            supported_inference_dtypes.append(weight_dtype)
+
+        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+
+    load_device = comfy.model_management.get_torch_device()
+
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)
+
+    controlnet_config["operations"] = operations
+    controlnet_config["dtype"] = unet_dtype
+    controlnet_config["device"] = comfy.model_management.unet_offload_device()
+    controlnet_config.pop("out_channels")
+    controlnet_config["hint_channels"] = 3
+    #controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
+    control_model = ControlNetCtrLoRA(**controlnet_config)
+
+    if pth:
+        if 'difference' in controlnet_data:
+            if model is not None:
+                comfy.model_management.load_models_gpu([model])
+                model_sd = model.model_state_dict()
+                for x in controlnet_data:
+                    c_m = "control_model."
+                    if x.startswith(c_m):
+                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
+                        if sd_key in model_sd:
+                            cd = controlnet_data[x]
+                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
+            else:
+                logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
+
+        class WeightsLoader(torch.nn.Module):
+            pass
+        w = WeightsLoader()
+        w.control_model = control_model
+        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
+    else:
+        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+
+    if len(missing) > 0:
+        logger.warning("missing controlnet keys: {}".format(missing))
+
+    if len(unexpected) > 0:
+        logger.debug("unexpected controlnet keys: {}".format(unexpected))
+
+    global_average_pooling = model_options.get("global_average_pooling", False)
+    control = CtrLoRAAdvanced(control_model, timestep_keyframe, global_average_pooling=global_average_pooling,
+                              load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    # load lora data onto the controlnet
+    if lora_path is not None:
+        load_lora_data(control, lora_path)
+
+    return control
+
+
+def load_lora_data(control: CtrLoRAAdvanced, lora_path: str, loaded_data: dict[str, Tensor]=None, lora_strength=1.0):
+    if loaded_data is None:
+        loaded_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
+    # check that lora_data contains keys with lora_layer
+    contains_lora_layers = False
+    for key in loaded_data:
+        if "lora_layer" in key:
+            contains_lora_layers = True
+    if not contains_lora_layers:
+        raise Exception(f"File '{lora_path}' is not a valid CtrLoRA lora model; does not contain any lora_layer keys.")
+
+    # now that we know we have a ctrlora file, separate keys into 'set' and 'lora' keys
+    data_set: dict[str, Tensor] = {}
+    data_lora: dict[str, Tensor] = {}
+
+    for key in list(loaded_data.keys()):
+        if 'lora_layer' in key:
+            data_lora[key] = loaded_data.pop(key)
+        else:
+            data_set[key] = loaded_data.pop(key)
+    # no keys should be left over
+    if len(loaded_data) > 0:
+        logger.warning("Not all keys from CtrLoRA lora model's loaded data were parsed!")
+
+    # turn set/lora data into corresponding patches;
+    patches = {}
+    # set will replace the values
+    for key, value in data_set.items():
+        # parse model key from key;
+        # remove "control_model."
+        model_key = key.replace("control_model.", "")
+        patches[model_key] = ("set", (value,))
+    # lora will do mm of up and down tensors
+    for down_key in data_lora:
+        # only process lora down keys; we will process both up+down at the same time
+        if ".up." in down_key:
+            continue
+        # get up version of down key
+        up_key = down_key.replace(".down.", ".up.")
+        # get key that will match up with model key;
+        # remove "lora_layer.down." and "control_model."
+        model_key = down_key.replace("lora_layer.down.", "").replace("control_model.", "")
+
+        weight_down = data_lora[down_key]
+        weight_up = data_lora[up_key]
+        # currently, ComfyUI expects 6 elements in 'lora' type, but for future-proofing add a bunch more with None
+        patches[model_key] = ("lora", (weight_up, weight_down, None, None, None, None,
+                                       None, None, None, None, None, None, None, None))
+
+    # now that patches are made, add them to model
+    control.control_model_wrapped.add_patches(patches, strength_patch=lora_strength)
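For orientation, a minimal usage sketch of the loader added above (the file paths are hypothetical; load_ctrlora and load_lora_data are the functions from this file):

    # hedged sketch: load a CtrLoRA base model, then attach a task-specific lora
    base_path = "models/controlnet/ctrlora_base.safetensors"        # hypothetical path
    lora_path = "models/controlnet/ctrlora_canny_lora.safetensors"  # hypothetical path

    control = load_ctrlora(base_path, lora_path)
    # or attach the lora separately, e.g. with a non-default strength:
    # control = load_ctrlora(base_path, lora_path=None)
    # load_lora_data(control, lora_path, lora_strength=0.8)

When ComfyUI's ModelPatcher later resolves these patches, a 'set' patch replaces the stored weight outright, while a 'lora' patch adds the matrix product of the up and down weights. A simplified sketch of that weight calculation (not ComfyUI's exact implementation, which also handles alpha scaling, dtype casting, and optional intermediate matrices):

    import torch

    def apply_patch_sketch(weight: torch.Tensor, patch_type: str, v: tuple, strength: float = 1.0):
        if patch_type == "set":
            return v[0].to(weight)  # replace the weight outright
        elif patch_type == "lora":
            up, down = v[0], v[1]   # remaining slots (alpha, mid weights, ...) unused in this sketch
            diff = torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1)).reshape(weight.shape)
            return weight + strength * diff.to(weight)
        raise ValueError(f"unhandled patch type: {patch_type}")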

adv_control/control_lllite.py

Lines changed: 12 additions & 47 deletions

@@ -15,12 +15,12 @@
 
 from .logger import logger
 from .utils import (AdvancedControlBase, TimestepKeyframeGroup, ControlWeights, broadcast_image_to_extend, extend_to_batch_size,
-                    deepcopy_with_sharing, prepare_mask_batch)
+                    prepare_mask_batch)
 
 
 # based on set_model_patch code in comfy/model_patcher.py
-def set_model_patch(model_options, patch, name):
-    to = model_options["transformer_options"]
+def set_model_patch(transformer_options, patch, name):
+    to = transformer_options
     # check if patch was already added
     if "patches" in to:
         current_patches = to["patches"].get(name, [])
@@ -30,11 +30,11 @@ def set_model_patch(model_options, patch, name):
         to["patches"] = {}
     to["patches"][name] = to["patches"].get(name, []) + [patch]
 
-def set_model_attn1_patch(model_options, patch):
-    set_model_patch(model_options, patch, "attn1_patch")
+def set_model_attn1_patch(transformer_options, patch):
+    set_model_patch(transformer_options, patch, "attn1_patch")
 
-def set_model_attn2_patch(model_options, patch):
-    set_model_patch(model_options, patch, "attn2_patch")
+def set_model_attn2_patch(transformer_options, patch):
+    set_model_patch(transformer_options, patch, "attn2_patch")
 
 
 def extra_options_to_module_prefix(extra_options):
@@ -115,26 +115,8 @@ def clone_with_control(self, control: AdvancedControlBase):
         return LLLitePatch(self.modules, self.patch_type, control)
 
     def cleanup(self):
-        #total_cleaned = 0
         for module in self.modules.values():
             module.cleanup()
-        #    total_cleaned += 1
-        #logger.info(f"cleaned modules: {total_cleaned}, {id(self)}")
-        #logger.error(f"cleanup LLLitePatch: {id(self)}")
-
-    # make sure deepcopy does not copy control, and deepcopied LLLitePatch should be assigned to control
-    # def __deepcopy__(self, memo):
-    #     self.cleanup()
-    #     to_return: LLLitePatch = deepcopy_with_sharing(self, shared_attribute_names = ['control'], memo=memo)
-    #     #logger.warn(f"patch {id(self)} turned into {id(to_return)}")
-    #     try:
-    #         if self.patch_type == self.ATTN1:
-    #             to_return.control.patch_attn1 = to_return
-    #         elif self.patch_type == self.ATTN2:
-    #             to_return.control.patch_attn2 = to_return
-    #     except Exception:
-    #         pass
-    #     return to_return
 
 
 # TODO: use comfy.ops to support fp8 properly
@@ -298,14 +280,6 @@ def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_
         self.latent_dims_div2 = None
         self.latent_dims_div4 = None
 
-    def live_model_patches(self, model_options):
-        set_model_attn1_patch(model_options, self.patch_attn1.set_control(self))
-        set_model_attn2_patch(model_options, self.patch_attn2.set_control(self))
-
-    # def patch_model(self, model: ModelPatcher):
-    #     model.set_model_attn1_patch(self.patch_attn1)
-    #     model.set_model_attn2_patch(self.patch_attn2)
-
     def set_cond_hint_inject(self, *args, **kwargs):
         to_return = super().set_cond_hint_inject(*args, **kwargs)
         # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
@@ -319,11 +293,11 @@ def pre_run_advanced(self, *args, **kwargs):
         self.patch_attn2.set_control(self)
         #logger.warn(f"in pre_run_advanced: {id(self)}")
 
-    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
+    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int, transformer_options: dict):
         # normal ControlNet stuff
         control_prev = None
         if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
 
         if self.timestep_range is not None:
             if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@@ -372,7 +346,9 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
             self.latent_dims_div4 = (new_h, new_w)
         # prepare mask
         self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
-        # done preparing; model patches will take care of everything now.
+        # done preparing; model patches will take care of everything now
+        set_model_attn1_patch(transformer_options, self.patch_attn1.set_control(self))
+        set_model_attn2_patch(transformer_options, self.patch_attn2.set_control(self))
         # return normal controlnet stuff
         return control_prev
 
@@ -393,17 +369,6 @@ def copy(self):
         self.copy_to(c)
         self.copy_to_advanced(c)
         return c
-
-    # deepcopy needs to properly keep track of objects to work between model.clone calls!
-    # def __deepcopy__(self, *args, **kwargs):
-    #     self.cleanup_advanced()
-    #     return self
-
-    # def get_models(self):
-    #     # get_models is called once at the start of every KSampler run - use to reset already_patched status
-    #     out = super().get_models()
-    #     logger.error(f"in get_models! {id(self)}")
-    #     return out
 
 
 def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
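This file shows the heart of the ModelPatcher rework: instead of baking attention patches into the model via the removed live_model_patches, the patches are now attached to transformer_options inside get_control_advanced, per sampling step. Downstream, ComfyUI's attention blocks look those patches up on every call; roughly (paraphrased from comfy/ldm/modules/attention.py, not exact):

    # rough sketch of how transformer_options patches are consumed during attention
    transformer_patches = transformer_options.get("patches", {})
    if "attn1_patch" in transformer_patches:
        for p in transformer_patches["attn1_patch"]:
            # each patch may modify query/context/value before attention runs
            n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)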

adv_control/control_plusplus.py

Lines changed: 4 additions & 4 deletions

@@ -22,7 +22,7 @@
 import comfy.model_detection
 import comfy.utils
 
-from .utils import (AdvancedControlBase, ControlWeights, ControlWeightType, TimestepKeyframeGroup, AbstractPreprocWrapper,
+from .utils import (AdvancedControlBase, ControlWeights, ControlWeightType, TimestepKeyframeGroup, AbstractPreprocWrapper, Extras,
                     extend_to_batch_size, broadcast_image_to_extend)
 from .logger import logger
 
@@ -239,7 +239,7 @@ def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: Timest
     def get_universal_weights(self) -> ControlWeights:
         def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
             if key == "middle":
-                return 1.0
+                return 1.0 * self.weights.extras.get(Extras.MIDDLE_MULT, 1.0)
             c_len = len(control[key])
             raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
             raw_weights = raw_weights[:-1]
@@ -276,10 +276,10 @@ def set_cond_hint_inject(self, *args, **kwargs):
         self.cond_hint_original = pp_group
         return to_return
 
-    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number):
+    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number, transformer_options):
         control_prev = None
         if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
 
         if self.timestep_range is not None:
             if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
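For intuition on cn_weights_func above, a small worked example (the 0.825 multiplier and the list length are illustrative values, not defaults taken from this file):

    # worked example of the universal weight falloff in cn_weights_func
    base_multiplier = 0.825   # illustrative value
    c_len = 12                # illustrative number of 'output' control tensors
    raw_weights = [base_multiplier ** float(c_len - i) for i in range(c_len + 1)]
    raw_weights = raw_weights[:-1]           # drop base_multiplier**0 == 1.0
    print(raw_weights[0], raw_weights[-1])   # ~0.0994 ... 0.825, a geometric falloff

The "middle" output now returns 1.0 * self.weights.extras.get(Extras.MIDDLE_MULT, 1.0), so workflows that never set the new MIDDLE_MULT extra keep the previous constant 1.0 behavior.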
