Commit 6e4aeb2

Fix channel padding with new Comfy core API (#106)
* Fix channel padding with new Comfy core API
* nit
1 parent 2cbfe39 commit 6e4aeb2
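
In short: earlier versions of this node monkey-patched ComfyUI's ModelPatcher.calculate_weight so that a "diff" patch whose conv weight has more channels than the base weight gets zero-padded into place during merging. Newer ComfyUI core can do this itself: a "diff" patch may carry an options dict, and (as this commit relies on) a {"pad_weight": True} entry asks core to perform the padding. The diff below deletes the monkey-patch and tags each "diff" patch with that flag instead.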

File tree

1 file changed: +12, -69 lines

layered_diffusion.py

Lines changed: 12 additions & 69 deletions
@@ -1,7 +1,6 @@
 import os
 from enum import Enum
 import torch
-import functools
 import copy
 from typing import Optional, List
 from dataclasses import dataclass
@@ -31,73 +30,6 @@
 load_layer_model_state_dict = load_torch_file
 
 
-# ------------ Start patching ComfyUI ------------
-def calculate_weight_adjust_channel(func):
-    """Patches ComfyUI's LoRA weight application to accept multi-channel inputs."""
-
-    @functools.wraps(func)
-    def calculate_weight(
-        self: ModelPatcher, patches, weight: torch.Tensor, key: str
-    ) -> torch.Tensor:
-        weight = func(self, patches, weight, key)
-
-        for p in patches:
-            alpha = p[0]
-            v = p[1]
-
-            # The recursion call should be handled in the main func call.
-            if isinstance(v, list):
-                continue
-
-            if len(v) == 1:
-                patch_type = "diff"
-            elif len(v) == 2:
-                patch_type = v[0]
-                v = v[1]
-
-            if patch_type == "diff":
-                w1 = v[0]
-                if all(
-                    (
-                        alpha != 0.0,
-                        w1.shape != weight.shape,
-                        w1.ndim == weight.ndim == 4,
-                    )
-                ):
-                    new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
-                    print(
-                        f"Merged with {key} channel changed from {weight.shape} to {new_shape}"
-                    )
-                    new_diff = alpha * comfy.model_management.cast_to_device(
-                        w1, weight.device, weight.dtype
-                    )
-                    new_weight = torch.zeros(size=new_shape).to(weight)
-                    new_weight[
-                        : weight.shape[0],
-                        : weight.shape[1],
-                        : weight.shape[2],
-                        : weight.shape[3],
-                    ] = weight
-                    new_weight[
-                        : new_diff.shape[0],
-                        : new_diff.shape[1],
-                        : new_diff.shape[2],
-                        : new_diff.shape[3],
-                    ] += new_diff
-                    new_weight = new_weight.contiguous().clone()
-                    weight = new_weight
-        return weight
-
-    return calculate_weight
-
-
-ModelPatcher.calculate_weight = calculate_weight_adjust_channel(
-    ModelPatcher.calculate_weight
-)
-
-# ------------ End patching ComfyUI ------------
-
-
 class LayeredDiffusionDecode:
     """
     Decode alpha channel value from pixel value.
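
Aside (not part of the diff): the padding that the deleted wrapper performed reduces to the standalone sketch below. Shapes are hypothetical; a layer-diffusion "diff" typically carries a conv weight with more channels than the base model's.

    import torch

    # Base conv weight and a diff with more input channels (shapes hypothetical).
    weight = torch.randn(320, 4, 3, 3)
    diff = torch.randn(320, 8, 3, 3)
    alpha = 1.0

    # Zero-pad the base weight to the per-dimension max shape, then add the diff
    # into the top-left corner -- the same effect as the removed wrapper.
    new_shape = [max(n, m) for n, m in zip(weight.shape, diff.shape)]
    new_weight = torch.zeros(new_shape, dtype=weight.dtype)
    new_weight[: weight.shape[0], : weight.shape[1], : weight.shape[2], : weight.shape[3]] = weight
    new_weight[: diff.shape[0], : diff.shape[1], : diff.shape[2], : diff.shape[3]] += alpha * diff

    assert new_weight.shape == (320, 8, 3, 3)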
@@ -323,8 +255,19 @@ def apply_layered_diffusion(
             model_dir=layer_model_root,
             file_name=self.model_file_name,
         )
+        def pad_diff_weight(v):
+            if len(v) == 1:
+                return ("diff", [v[0], {"pad_weight": True}])
+            elif len(v) == 2 and v[0] == "diff":
+                return ("diff", [v[1][0], {"pad_weight": True}])
+            else:
+                return v
+
         layer_lora_state_dict = load_layer_model_state_dict(model_path)
-        layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict)
+        layer_lora_patch_dict = {
+            k: pad_diff_weight(v)
+            for k, v in to_lora_patch_dict(layer_lora_state_dict).items()
+        }
         work_model = model.clone()
         work_model.add_patches(layer_lora_patch_dict, weight)
         return (work_model,)
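
For reference, this is what the new code hands to add_patches: every "diff" entry is normalized to ("diff", [tensor, {"pad_weight": True}]), and a sufficiently recent ComfyUI core is then expected to zero-pad the base weight itself when shapes disagree (per the commit title; the core-side behavior is not shown in this diff). A minimal sketch of the normalization, with a hypothetical tensor:

    import torch

    def pad_diff_weight(v):
        # Same normalization as in the diff above.
        if len(v) == 1:
            return ("diff", [v[0], {"pad_weight": True}])
        elif len(v) == 2 and v[0] == "diff":
            return ("diff", [v[1][0], {"pad_weight": True}])
        else:
            return v

    w = torch.zeros(320, 8, 3, 3)  # hypothetical widened conv diff

    # The bare one-tuple form and the already-tagged "diff" form both
    # normalize to the same flagged patch value.
    tag, (tensor, opts) = pad_diff_weight((w,))
    assert tag == "diff" and tensor is w and opts == {"pad_weight": True}
    tag, (tensor, opts) = pad_diff_weight(("diff", (w,)))
    assert tag == "diff" and tensor is w and opts == {"pad_weight": True}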
