26 changes: 25 additions & 1 deletion mergekit/merge_methods/sce.py
@@ -1,6 +1,7 @@
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only

import logging
from typing import List, Optional

import torch
@@ -24,8 +25,31 @@ def sce_merge(
) -> torch.Tensor:
    if not tensors:
        return base_tensor

    mask_dtype = torch.int8 if int8_mask else base_tensor.dtype
    task_vectors = torch.stack([t - base_tensor for t in tensors], dim=0)

    # Process tensors to handle shape mismatches
    valid_task_vectors = []

    for idx, t in enumerate(tensors):
        # Convert to base dtype
        t = t.to(base_tensor.dtype)

        # Handle shape mismatch - resize to base dimensions
        if t.shape != base_tensor.shape:
            # Slice tensor to match base_tensor dimensions
            t = t[: base_tensor.shape[0], : base_tensor.shape[1]]
Bug: Submatrix slicing crashes on 1D tensors

The submatrix slicing t[: base_tensor.shape[0], : base_tensor.shape[1]] assumes tensors are 2D. If a 1D tensor (like a bias vector or layer norm weight) has a shape mismatch, accessing base_tensor.shape[1] will raise an IndexError. The reference implementation in generalized_task_arithmetic.py avoids this by only applying 2D slicing to is_embed weights, which are guaranteed to be 2D embedding matrices. The SCE implementation lacks this guard and applies 2D indexing unconditionally to any shape-mismatched tensor.
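A minimal sketch of the guard the comment describes, mirroring the is_embed convention from generalized_task_arithmetic.py (the helper name and the is_embed flag are assumptions for illustration, not the SCE code's actual API):

```python
import torch

def maybe_slice_to_base(t: torch.Tensor, base: torch.Tensor, is_embed: bool):
    """Sketch: apply 2D submatrix slicing only to embedding weights.

    `is_embed` is assumed to come from the weight's metadata, as in
    generalized_task_arithmetic.py. Mismatched non-embedding tensors
    (e.g. 1D biases or layer norm weights) are skipped via None instead
    of being hit with unconditional 2D indexing.
    """
    if t.shape == base.shape:
        return t
    if is_embed and t.dim() == 2 and base.dim() == 2:
        return t[: base.shape[0], : base.shape[1]]
    return None  # skip: 2D slicing would raise IndexError on 1D tensors
```

With this guard, a shape-mismatched 1D tensor yields None (and the caller can skip it) rather than crashing on base_tensor.shape[1].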


            logging.warning(f"Using submatrix of tensor {idx}")

        # Compute task vector (delta)
        task_vector = t - base_tensor
Bug: Slicing fails when tensor is smaller than base

The submatrix slicing t[: base_tensor.shape[0], : base_tensor.shape[1]] only works when t is larger than base_tensor. If t is smaller in any dimension (e.g., a model with smaller vocabulary), the slice operation returns t unchanged, and the subsequent subtraction t - base_tensor on line 45 will raise a broadcasting error due to shape mismatch. The reference implementation in generalized_task_arithmetic.py handles this by skipping non-embedding tensors with mismatches entirely, but the SCE implementation unconditionally attempts to proceed.
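One way to express the skip behavior the comment attributes to the reference implementation is to slice only when t covers base_tensor in every dimension and drop the tensor otherwise (the helper name is hypothetical; this is a sketch, not the SCE code's fix):

```python
import torch

def submatrix_or_none(t: torch.Tensor, base: torch.Tensor):
    """Sketch: return a view of `t` sliced down to `base`'s shape, or None
    when `t` is smaller than `base` in any dimension. Slicing cannot grow
    a tensor, so proceeding would leave a shape mismatch and the later
    `t - base` subtraction would raise a broadcasting error.
    """
    if t.dim() != base.dim():
        return None
    if any(ts < bs for ts, bs in zip(t.shape, base.shape)):
        return None
    return t[tuple(slice(0, s) for s in base.shape)]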


        valid_task_vectors.append(task_vector)

    # If no valid tensors remain, return base
    if not valid_task_vectors:
        return base_tensor

    task_vectors = torch.stack(valid_task_vectors, dim=0)

    if select_topk < 1:
        mask = sce_mask(task_vectors, select_topk, mask_dtype)