Commit 6bbdcd2

Support weight padding on diff weight patch (comfyanonymous#4576)

1 parent ab13000

1 file changed: comfy/lora.py (44 additions, 4 deletions)
@@ -16,6 +16,7 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
 
+from __future__ import annotations
 import comfy.utils
 import comfy.model_management
 import comfy.model_base
@@ -347,6 +348,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
         weight[:] = weight_calc
     return weight
 
+def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
+    """
+    Pad a tensor to a new shape with zeros.
+
+    Args:
+        tensor (torch.Tensor): The original tensor to be padded.
+        new_shape (list[int]): The desired shape of the padded tensor.
+
+    Returns:
+        torch.Tensor: A new tensor padded with zeros to the specified shape.
+
+    Note:
+        Raises a ValueError if the new shape is smaller than the original
+        tensor in any dimension, or has a different number of dimensions.
+    """
+    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
+        raise ValueError("The new shape must be larger than the original tensor in all dimensions")
+
+    if len(new_shape) != len(tensor.shape):
+        raise ValueError("The new shape must have the same number of dimensions as the original tensor")
+
+    # Create a new tensor filled with zeros
+    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
+
+    # Create slicing tuples covering the original tensor's extent in each dimension
+    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
+    new_slices = tuple(slice(0, dim) for dim in tensor.shape)
+
+    # Copy the original tensor into the top-left corner of the new tensor
+    padded_tensor[new_slices] = tensor[orig_slices]
+
+    return padded_tensor
+
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
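
The helper added above zero-pads a tensor out to a larger shape, keeping the
original values in the leading (top-left) corner. A minimal sketch of its
behavior, assuming it is importable as comfy.lora.pad_tensor_to_shape after
this commit (the tensor values here are illustrative only):

    import torch
    from comfy.lora import pad_tensor_to_shape

    t = torch.ones(2, 3)
    padded = pad_tensor_to_shape(t, [4, 5])
    assert padded.shape == (4, 5)
    assert torch.equal(padded[:2, :3], t)  # original values preserved
    assert padded[2:, :].sum() == 0        # padding region is zero-filled
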
@@ -375,12 +409,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             v = v[1]
 
         if patch_type == "diff":
-            w1 = v[0]
+            diff: torch.Tensor = v[0]
+            # An extra flag to pad the weight if the diff's shape is larger than the weight's
+            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
+            if do_pad_weight and diff.shape != weight.shape:
+                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
+                weight = pad_tensor_to_shape(weight, diff.shape)
+
             if strength != 0.0:
-                if w1.shape != weight.shape:
-                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                if diff.shape != weight.shape:
+                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                 else:
-                    weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
+                    weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
         elif patch_type == "lora": #lora/locon
             mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
             mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
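
For the "diff" patch type, v now carries an optional second element: a dict of
flags, where a truthy v[1]['pad_weight'] asks for the weight to be grown to the
diff's shape before merging. A standalone sketch of that branch
(apply_diff_patch is a hypothetical name, and the function/cast plumbing from
calculate_weight is simplified away):

    import torch
    from comfy.lora import pad_tensor_to_shape

    def apply_diff_patch(weight, v, strength=1.0):
        diff = v[0]
        do_pad_weight = len(v) > 1 and v[1]['pad_weight']
        if do_pad_weight and diff.shape != weight.shape:
            weight = pad_tensor_to_shape(weight, diff.shape)
        if strength != 0.0 and diff.shape == weight.shape:
            weight = weight + strength * diff.to(weight.device, weight.dtype)
        return weight

    # A 4x4 diff applied to a 2x2 weight grows the weight to 4x4 first:
    w = torch.ones(2, 2)
    d = torch.full((4, 4), 0.5)
    out = apply_diff_patch(w, (d, {'pad_weight': True}))
    assert out.shape == (4, 4)
    assert out[0, 0] == 1.5 and out[3, 3] == 0.5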
