|
16 | 16 | along with this program. If not, see <https://www.gnu.org/licenses/>. |
17 | 17 | """ |
18 | 18 |
|
| 19 | +from __future__ import annotations |
19 | 20 | import comfy.utils |
20 | 21 | import comfy.model_management |
21 | 22 | import comfy.model_base |
@@ -347,6 +348,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat |
347 | 348 | weight[:] = weight_calc |
348 | 349 | return weight |
349 | 350 |
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
    """
    Pad a tensor to a new shape with zeros.

    Args:
        tensor (torch.Tensor): The original tensor to be padded.
        new_shape (list[int]): The desired shape of the padded tensor. Must have
            the same number of dimensions as ``tensor`` and be at least as large
            as ``tensor`` in every dimension.

    Returns:
        torch.Tensor: A new tensor of shape ``new_shape`` (same dtype/device as
        ``tensor``), containing ``tensor`` in its leading corner and zeros elsewhere.

    Raises:
        ValueError: If ``new_shape`` has a different number of dimensions than
            ``tensor``, or is smaller than ``tensor`` in any dimension.
    """
    # Validate the rank first: the per-dimension comparison below indexes
    # tensor.shape[i] for every i in new_shape, which would raise IndexError
    # (instead of the intended ValueError) if new_shape had extra dimensions.
    if len(new_shape) != len(tensor.shape):
        raise ValueError("The new shape must have the same number of dimensions as the original tensor")

    if any(n < o for n, o in zip(new_shape, tensor.shape)):
        raise ValueError("The new shape must be larger than the original tensor in all dimensions")

    # Start from an all-zeros tensor and copy the original into its leading corner.
    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
    corner = tuple(slice(0, dim) for dim in tensor.shape)
    padded_tensor[corner] = tensor

    return padded_tensor
| 383 | + |
350 | 384 | def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): |
351 | 385 | for p in patches: |
352 | 386 | strength = p[0] |
@@ -375,12 +409,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): |
375 | 409 | v = v[1] |
376 | 410 |
|
377 | 411 | if patch_type == "diff": |
378 | | - w1 = v[0] |
| 412 | + diff: torch.Tensor = v[0] |
| 413 | + # An extra flag to pad the weight if the diff's shape is larger than the weight |
| 414 | + do_pad_weight = len(v) > 1 and v[1]['pad_weight'] |
| 415 | + if do_pad_weight and diff.shape != weight.shape: |
| 416 | + logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape)) |
| 417 | + weight = pad_tensor_to_shape(weight, diff.shape) |
| 418 | + |
379 | 419 | if strength != 0.0: |
380 | | - if w1.shape != weight.shape: |
381 | | - logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) |
| 420 | + if diff.shape != weight.shape: |
| 421 | + logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape)) |
382 | 422 | else: |
383 | | - weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)) |
| 423 | + weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype)) |
384 | 424 | elif patch_type == "lora": #lora/locon |
385 | 425 | mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) |
386 | 426 | mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) |
|
0 commit comments