|
@@ -1,7 +1,6 @@
 import os
 from enum import Enum
 import torch
-import functools
 import copy
 from typing import Optional, List
 from dataclasses import dataclass
|
@@ -31,73 +30,6 @@
 load_layer_model_state_dict = load_torch_file


-# ------------ Start patching ComfyUI ------------
-def calculate_weight_adjust_channel(func):
-    """Patches ComfyUI's LoRA weight application to accept multi-channel inputs."""
-
-    @functools.wraps(func)
-    def calculate_weight(
-        self: ModelPatcher, patches, weight: torch.Tensor, key: str
-    ) -> torch.Tensor:
-        weight = func(self, patches, weight, key)
-
-        for p in patches:
-            alpha = p[0]
-            v = p[1]
-
-            # The recursion call should be handled in the main func call.
-            if isinstance(v, list):
-                continue
-
-            if len(v) == 1:
-                patch_type = "diff"
-            elif len(v) == 2:
-                patch_type = v[0]
-                v = v[1]
-
-            if patch_type == "diff":
-                w1 = v[0]
-                if all(
-                    (
-                        alpha != 0.0,
-                        w1.shape != weight.shape,
-                        w1.ndim == weight.ndim == 4,
-                    )
-                ):
-                    new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
-                    print(
-                        f"Merged with {key} channel changed from {weight.shape} to {new_shape}"
-                    )
-                    new_diff = alpha * comfy.model_management.cast_to_device(
-                        w1, weight.device, weight.dtype
-                    )
-                    new_weight = torch.zeros(size=new_shape).to(weight)
-                    new_weight[
-                        : weight.shape[0],
-                        : weight.shape[1],
-                        : weight.shape[2],
-                        : weight.shape[3],
-                    ] = weight
-                    new_weight[
-                        : new_diff.shape[0],
-                        : new_diff.shape[1],
-                        : new_diff.shape[2],
-                        : new_diff.shape[3],
-                    ] += new_diff
-                    new_weight = new_weight.contiguous().clone()
-                    weight = new_weight
-        return weight
-
-    return calculate_weight
-
-
-ModelPatcher.calculate_weight = calculate_weight_adjust_channel(
-    ModelPatcher.calculate_weight
-)
-
-# ------------ End patching ComfyUI ------------
-
-
 class LayeredDiffusionDecode:
     """
     Decode alpha channel value from pixel value.
@@ -323,8 +255,19 @@ def apply_layered_diffusion( |
             model_dir=layer_model_root,
             file_name=self.model_file_name,
         )
+        def pad_diff_weight(v):
+            if len(v) == 1:
+                return ("diff", [v[0], {"pad_weight": True}])
+            elif len(v) == 2 and v[0] == "diff":
+                return ("diff", [v[1][0], {"pad_weight": True}])
+            else:
+                return v
+
         layer_lora_state_dict = load_layer_model_state_dict(model_path)
-        layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict)
+        layer_lora_patch_dict = {
+            k: pad_diff_weight(v)
+            for k, v in to_lora_patch_dict(layer_lora_state_dict).items()
+        }
         work_model = model.clone()
         work_model.add_patches(layer_lora_patch_dict, weight)
         return (work_model,)
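Note: after this change the channel-padding logic is no longer implemented in this repo at all. `pad_diff_weight` only tags each "diff" patch with a `{"pad_weight": True}` flag, and the padding itself is left to ComfyUI's own weight calculation when the patch is applied. The standalone Python sketch below illustrates what that padding amounts to for a 4-D conv weight whose diff has more channels; `apply_padded_diff` is a hypothetical helper for illustration, not ComfyUI's API, and it assumes the diff is at least as large as the base weight in every dimension (the removed wrapper took an element-wise max of the two shapes instead).

import torch


def apply_padded_diff(weight: torch.Tensor, diff: torch.Tensor, alpha: float) -> torch.Tensor:
    # Hypothetical sketch of the behavior a "diff" patch carrying
    # {"pad_weight": True} asks for: zero-pad the base weight up to the
    # diff's shape, then merge alpha * diff into it.
    if diff.shape != weight.shape:
        padded = torch.zeros(diff.shape, dtype=weight.dtype, device=weight.device)
        # Keep the original values in the leading slice of every dimension;
        # the newly added channels stay zero until the diff fills them.
        padded[tuple(slice(0, s) for s in weight.shape)] = weight
        weight = padded
    return weight + alpha * diff.to(device=weight.device, dtype=weight.dtype)


# Example: a conv weight grown from 4 to 8 input channels (shapes are
# illustrative only).
base = torch.randn(320, 4, 3, 3)
diff = torch.randn(320, 8, 3, 3)
merged = apply_padded_diff(base, diff, alpha=1.0)
assert merged.shape == (320, 8, 3, 3)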
|