|
| 1 | +import torch |
| 2 | +from torch import Tensor |
| 3 | + |
| 4 | +from comfy.cldm.cldm import ControlNet as ControlNetCLDM |
| 5 | +import comfy.model_detection |
| 6 | +import comfy.model_management |
| 7 | +import comfy.ops |
| 8 | +import comfy.utils |
| 9 | + |
| 10 | +from comfy.ldm.modules.diffusionmodules.util import ( |
| 11 | + zero_module, |
| 12 | + timestep_embedding, |
| 13 | +) |
| 14 | + |
| 15 | +from .control import ControlNetAdvanced |
| 16 | +from .utils import TimestepKeyframeGroup |
| 17 | +from .logger import logger |
| 18 | + |
| 19 | + |
class ControlNetCtrLoRA(ControlNetCLDM):
    """ControlNet variant whose input blocks consume the hint tensor directly.

    The inherited ``input_hint_block`` is removed after construction, and
    ``forward`` runs ``hint`` (not ``x``) through the input blocks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # forward() never calls the hint block, so drop it from the module.
        del self.input_hint_block

    def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
        """Compute the control residuals for ``hint``.

        Args:
            x: Only its dtype (and, when class labels are enabled, its batch
                size) is read here.
            hint: Tensor fed through the input blocks after being cast to
                ``x.dtype``.
            timesteps: Passed to ``timestep_embedding``.
            context: Forwarded unchanged to every block.
            y: Class labels; required when ``self.num_classes`` is set.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``"output"`` (one zero-conv result per input block, in
            block order) and ``"middle"`` (a single-element list holding the
            middle block's projected output).
        """
        embedding = self.time_embed(
            timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        )

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            embedding = embedding + self.label_emb(y)

        block_outputs = []
        hidden = hint.to(dtype=x.dtype)
        for block, zero_conv in zip(self.input_blocks, self.zero_convs):
            hidden = block(hidden, embedding, context)
            block_outputs.append(zero_conv(hidden, embedding, context))

        hidden = self.middle_block(hidden, embedding, context)
        middle_outputs = [self.middle_block_out(hidden, embedding, context)]

        return {"middle": middle_outputs, "output": block_outputs}
| 46 | + |
| 47 | + |
class CtrLoRAAdvanced(ControlNetAdvanced):
    """Advanced ControlNet wrapper used for CtrLoRA control models."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): these flags are presumably read by the base class to
        # decide how the hint is VAE-encoded -- confirm in ControlNetAdvanced.
        self.mult_by_ratio_when_vae = False
        self.require_vae = True

    def pre_run_advanced(self, model, percent_to_timestep_function):
        """Run the base pre-run step, then remember the model's latent format."""
        super().pre_run_advanced(model, percent_to_timestep_function)
        # LatentFormat object, used to process_in the latent cond hint.
        self.latent_format = model.latent_format

    def cleanup_advanced(self):
        """Run the base cleanup, then release the stored latent format."""
        super().cleanup_advanced()
        if self.latent_format is None:
            return
        del self.latent_format
        self.latent_format = None

    def copy(self):
        """Return a new CtrLoRAAdvanced sharing this instance's control model."""
        clone = CtrLoRAAdvanced(
            self.control_model,
            self.timestep_keyframes,
            global_average_pooling=self.global_average_pooling,
            load_device=self.load_device,
            manual_cast_dtype=self.manual_cast_dtype,
        )
        # Share the underlying model and its wrapper instead of duplicating them.
        clone.control_model = self.control_model
        clone.control_model_wrapped = self.control_model_wrapped
        self.copy_to(clone)
        self.copy_to_advanced(clone)
        return clone
| 71 | + |
| 72 | + |
def load_ctrlora(base_path: str, lora_path: str,
                 base_data: dict[str, Tensor]=None, lora_data: dict[str, Tensor]=None,
                 timestep_keyframe: TimestepKeyframeGroup=None, model=None, model_options=None):
    """Build a CtrLoRAAdvanced control object from a CtrLoRA base checkpoint.

    Args:
        base_path: Path of the base checkpoint; read only when ``base_data``
            is None, and used in error messages.
        lora_path: Path of the CtrLoRA lora file; read only when ``lora_data``
            is None. When both are None, no lora is applied.
        base_data: Already-loaded base state dict (skips reading ``base_path``).
        lora_data: Already-loaded lora state dict (skips reading ``lora_path``).
        timestep_keyframe: Forwarded to ``CtrLoRAAdvanced``.
        model: Only used when the checkpoint has a ``'difference'`` key; its
            ``diffusion_model.*`` weights are then added in place onto the
            matching ``control_model.*`` tensors.
        model_options: Optional dict; the keys ``"dtype"``,
            ``"custom_operations"`` and ``"global_average_pooling"`` are read.

    Returns:
        The constructed ``CtrLoRAAdvanced`` instance.

    Raises:
        Exception: If the base data has no ``lora_layer`` keys, or neither
            expected ``zero_convs.0.0.weight`` key is present.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if model_options is None:
        model_options = {}

    if base_data is None:
        base_data = comfy.utils.load_torch_file(base_path, safe_load=True)
    controlnet_data = base_data

    # A CtrLoRA base model must carry lora_layer keys.
    if not any("lora_layer" in key for key in base_data):
        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; does not contain any lora_layer keys.")

    # The checkpoint may store its weights with or without a "control_model." prefix.
    pth_key = 'control_model.zero_convs.0.0.weight'
    plain_key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        prefix = "control_model."
    elif plain_key in controlnet_data:
        pth = False
        prefix = ""
    else:
        raise Exception(f"File '{base_path}' is not a valid CtrLoRA base model; "
                        f"contains neither '{pth_key}' nor '{plain_key}'.")

    model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
    supported_inference_dtypes = list(model_config.supported_inference_dtypes)
    controlnet_config = model_config.unet_config

    unet_dtype = model_options.get("dtype", None)
    if unet_dtype is None:
        # No explicit dtype requested: also allow the dtype the weights are stored in.
        weight_dtype = comfy.utils.weight_dtype(controlnet_data)
        if weight_dtype is not None:
            supported_inference_dtypes.append(weight_dtype)
        unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)

    load_device = comfy.model_management.get_torch_device()

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    operations = model_options.get("custom_operations", None)
    if operations is None:
        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype)

    controlnet_config["operations"] = operations
    controlnet_config["dtype"] = unet_dtype
    controlnet_config["device"] = comfy.model_management.unet_offload_device()
    controlnet_config.pop("out_channels")
    # Fixed at 3 rather than read from the checkpoint's input_hint_block weights.
    controlnet_config["hint_channels"] = 3
    control_model = ControlNetCtrLoRA(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                c_m = "control_model."
                for x in controlnet_data:
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            # In-place add: the stored tensor becomes diff + base weight.
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # Wrap the model so the "control_model."-prefixed keys resolve on load.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logger.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logger.debug("unexpected controlnet keys: {}".format(unexpected))

    global_average_pooling = model_options.get("global_average_pooling", False)
    control = CtrLoRAAdvanced(control_model, timestep_keyframe, global_average_pooling=global_average_pooling,
                              load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    # Load lora data onto the controlnet; pre-loaded lora_data takes the place of reading lora_path.
    if lora_path is not None or lora_data is not None:
        load_lora_data(control, lora_path, loaded_data=lora_data)

    return control
| 176 | + |
| 177 | + |
def load_lora_data(control: CtrLoRAAdvanced, lora_path: str, loaded_data: dict[str, Tensor]=None, lora_strength=1.0):
    """Turn a CtrLoRA lora state dict into patches and add them to ``control``.

    Keys containing ``lora_layer`` are paired (down + up) into ``"lora"``
    patches; every other key becomes a ``"set"`` patch that replaces the
    value. The ``"control_model."`` prefix is stripped from patch keys.

    Args:
        control: Object whose ``control_model_wrapped.add_patches`` receives
            the patches.
        lora_path: Path of the lora file; read only when ``loaded_data`` is
            None, and used in the error message.
        loaded_data: Already-loaded lora state dict. It is not modified.
        lora_strength: Passed to ``add_patches`` as ``strength_patch``.

    Raises:
        Exception: If the data contains no ``lora_layer`` keys.
        KeyError: If a down key has no matching up key.
    """
    if loaded_data is None:
        loaded_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
    # A CtrLoRA lora file must carry lora_layer keys.
    if not any("lora_layer" in key for key in loaded_data):
        raise Exception(f"File '{lora_path}' is not a valid CtrLoRA lora model; does not contain any lora_layer keys.")

    # Separate keys into 'set' and 'lora' groups without touching the caller's dict.
    data_set: dict[str, Tensor] = {}
    data_lora: dict[str, Tensor] = {}
    for key, value in loaded_data.items():
        if 'lora_layer' in key:
            data_lora[key] = value
        else:
            data_set[key] = value

    patches = {}
    # 'set' patches replace the values outright.
    for key, value in data_set.items():
        model_key = key.replace("control_model.", "")
        patches[model_key] = ("set", (value,))
    # 'lora' patches do mm of the up and down tensors.
    for down_key, weight_down in data_lora.items():
        # Up keys are consumed together with their down counterpart below.
        if ".up." in down_key:
            continue
        up_key = down_key.replace(".down.", ".up.")
        # Drop "lora_layer.down." and "control_model." to get the model key.
        model_key = down_key.replace("lora_layer.down.", "").replace("control_model.", "")

        weight_up = data_lora[up_key]
        # currently, ComfyUI expects 6 elements in 'lora' type, but for future-proofing add a bunch more with None
        patches[model_key] = ("lora", (weight_up, weight_down) + (None,) * 12)

    # Now that the patches are made, add them to the model.
    control.control_model_wrapped.add_patches(patches, strength_patch=lora_strength)
0 commit comments