diff --git a/mergekit/merge_methods/sce.py b/mergekit/merge_methods/sce.py index bb971717..5d7f14b8 100644 --- a/mergekit/merge_methods/sce.py +++ b/mergekit/merge_methods/sce.py @@ -1,6 +1,7 @@ # Copyright (C) 2025 Arcee AI # SPDX-License-Identifier: LGPL-3.0-only +import logging from typing import List, Optional import torch @@ -24,8 +25,31 @@ def sce_merge( ) -> torch.Tensor: if not tensors: return base_tensor + mask_dtype = torch.int8 if int8_mask else base_tensor.dtype - task_vectors = torch.stack([t - base_tensor for t in tensors], dim=0) + + # Process tensors to handle shape mismatches + valid_task_vectors = [] + + for idx, t in enumerate(tensors): + # Convert to base dtype + t = t.to(base_tensor.dtype) + + # Handle shape mismatch - resize to base dimensions + if t.shape != base_tensor.shape: + # Slice tensor to match base_tensor dimensions + t = t[: base_tensor.shape[0], : base_tensor.shape[1]] + logging.warning(f"Using submatrix of tensor {idx}") + + # Compute task vector (delta) + task_vector = t - base_tensor + valid_task_vectors.append(task_vector) + + # If no valid tensors remain, return base + if not valid_task_vectors: + return base_tensor + + task_vectors = torch.stack(valid_task_vectors, dim=0) if select_topk < 1: mask = sce_mask(task_vectors, select_topk, mask_dtype)