Added submatrices support for the SCE method #651
base: main
Changes from all commits
```diff
@@ -1,6 +1,7 @@
 # Copyright (C) 2025 Arcee AI
 # SPDX-License-Identifier: LGPL-3.0-only
 
+import logging
 from typing import List, Optional
 
 import torch
@@ -24,8 +25,31 @@ def sce_merge(
 ) -> torch.Tensor:
     if not tensors:
         return base_tensor
 
     mask_dtype = torch.int8 if int8_mask else base_tensor.dtype
-    task_vectors = torch.stack([t - base_tensor for t in tensors], dim=0)
+    # Process tensors to handle shape mismatches
+    valid_task_vectors = []
+
+    for idx, t in enumerate(tensors):
+        # Convert to base dtype
+        t = t.to(base_tensor.dtype)
+
+        # Handle shape mismatch - resize to base dimensions
+        if t.shape != base_tensor.shape:
+            # Slice tensor to match base_tensor dimensions
+            t = t[: base_tensor.shape[0], : base_tensor.shape[1]]
+            logging.warning(f"Using submatrix of tensor {idx}")
+
+        # Compute task vector (delta)
+        task_vector = t - base_tensor
```
Bug: Slicing fails when tensor is smaller than base

The submatrix slicing `t[: base_tensor.shape[0], : base_tensor.shape[1]]` only truncates tensors that are larger than `base_tensor`. If a tensor is smaller along either dimension, the slice leaves it unchanged, and the subsequent `task_vector = t - base_tensor` fails with a shape mismatch, as shown in the sketch below (the diff resumes after it).
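A minimal standalone repro of this failure mode, using made-up shapes rather than anything from the PR:

```python
import torch

base_tensor = torch.zeros(4, 8)
t = torch.zeros(2, 8)  # smaller than the base along dim 0

# The slice is a no-op for a tensor that is already smaller than the base...
t = t[: base_tensor.shape[0], : base_tensor.shape[1]]
print(t.shape)  # torch.Size([2, 8]) -- unchanged

# ...so it is the delta computation that fails, not the slice itself.
try:
    task_vector = t - base_tensor
except RuntimeError as err:
    print(f"shape mismatch: {err}")
```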
```diff
+
+        valid_task_vectors.append(task_vector)
+
+    # If no valid tensors remain, return base
+    if not valid_task_vectors:
+        return base_tensor
+
+    task_vectors = torch.stack(valid_task_vectors, dim=0)
+
     if select_topk < 1:
         mask = sce_mask(task_vectors, select_topk, mask_dtype)
```
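For the case the PR targets, e.g. two checkpoints whose embedding matrices differ only in vocabulary size, the slice keeps the leading submatrix that overlaps the base weight so the delta can be computed. A small illustration with made-up shapes:

```python
import torch

base_tensor = torch.randn(100, 16)  # base model weight, e.g. a 100-token embedding
t = torch.randn(112, 16)            # same layer from a model with 12 extra tokens

if t.shape != base_tensor.shape:
    # Keep only the rows/columns that also exist in the base weight
    t = t[: base_tensor.shape[0], : base_tensor.shape[1]]

task_vector = t - base_tensor
print(task_vector.shape)  # torch.Size([100, 16])
```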
Bug: Submatrix slicing crashes on 1D tensors

The submatrix slicing `t[: base_tensor.shape[0], : base_tensor.shape[1]]` assumes tensors are 2D. If a 1D tensor (like a bias vector or layer norm weight) has a shape mismatch, accessing `base_tensor.shape[1]` will raise an `IndexError`. The reference implementation in `generalized_task_arithmetic.py` avoids this by only applying 2D slicing to `is_embed` weights, which are guaranteed to be 2D embedding matrices. The SCE implementation lacks this guard and applies 2D indexing unconditionally to any shape-mismatched tensor.
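One possible guard, sketched here rather than taken from the PR or from `generalized_task_arithmetic.py`: only attempt the slice on 2D tensors that fully contain the base shape, and skip anything else. The helper name and the skip-on-mismatch behavior are illustrative assumptions.

```python
import logging
from typing import Optional

import torch


def maybe_slice_to_base(
    t: torch.Tensor, base_tensor: torch.Tensor, idx: int
) -> Optional[torch.Tensor]:
    """Hypothetical helper: align t with base_tensor's shape, or return None to skip it."""
    if t.shape == base_tensor.shape:
        return t
    # Only slice 2D weights that are at least as large as the base along both dims;
    # this avoids IndexError on 1D tensors and no-op slices on smaller tensors.
    if (
        t.dim() == 2
        and base_tensor.dim() == 2
        and t.shape[0] >= base_tensor.shape[0]
        and t.shape[1] >= base_tensor.shape[1]
    ):
        logging.warning(f"Using submatrix of tensor {idx}")
        return t[: base_tensor.shape[0], : base_tensor.shape[1]]
    logging.warning(f"Skipping tensor {idx} with incompatible shape {tuple(t.shape)}")
    return None
```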