1 change: 1 addition & 0 deletions README.md
@@ -257,6 +257,7 @@ For detailed explanations, parameter descriptions, and use cases for each method
| [**Multi-SLERP** (`multislerp`)](docs/merge_methods.md#multi-slerp-multislerp) | Barycentric SLERP for multiple models. | ≥2 | * | Spherical interpolation for >2 models. |
| [**Karcher Mean** (`karcher`)](docs/merge_methods.md#karcher-mean-karcher) | Riemannian barycenter of model parameters. | ≥2 | - | Geometrically sound averaging on manifolds. |
| [**Task Arithmetic** (`task_arithmetic`)](docs/merge_methods.md#task-arithmetic-task_arithmetic) | Linearly combine "task vectors" (differences from a base). | ≥2 | ✓ | Transferring/combining fine-tuned skills. |
| [**Core Space** (`core_space`)](docs/merge_methods.md#core-space-core_space) | SVD-aligned LoRA merging in a compact core subspace. | ≥2 | ✓ | Efficient LoRA merging, heterogeneous ranks, subspace alignment. |
| [**TIES** (`ties`)](docs/merge_methods.md#ties-merging-ties) | Task arithmetic + sparsification & sign consensus. | ≥2 | ✓ | Merging many models, reducing interference. |
| [**DARE** (`dare_linear`, `dare_ties`)](docs/merge_methods.md#dare-dare_linear-dare_ties) | Task arithmetic + random pruning & rescaling. | ≥2 | ✓ | Robust skill retention, similar to TIES. |
| [**DELLA** (`della`, `della_linear`)](docs/merge_methods.md#della-della-della_linear) | Task arithmetic + adaptive magnitude-based pruning. | ≥2 | ✓ | Prioritizing important changes, reducing interference. |
41 changes: 41 additions & 0 deletions docs/merge_methods.md
@@ -12,6 +12,7 @@
- [Karcher Mean (`karcher`)](#karcher-mean-karcher)
- [Task Vector Methods](#task-vector-methods)
- [Task Arithmetic (`task_arithmetic`)](#task-arithmetic-task_arithmetic)
- [Core Space (`core_space`)](#core-space-core_space)
- [TIES-Merging (`ties`)](#ties-merging-ties)
- [DARE (`dare_linear`, `dare_ties`)](#dare-dare_linear-dare_ties)
- [DELLA (`della`, `della_linear`)](#della-della-della_linear)
@@ -149,6 +150,46 @@ This guide provides detailed information about the various model merging algorit

**Reference:** [Editing Models with Task Arithmetic](https://arxiv.org/abs/2212.04089)

### Core Space (`core_space`)

**Concept:** Merges LoRA-adapted models by projecting them into a shared, aligned core space defined by SVD-based reference bases. The merge operates in a compact subspace for efficiency while preserving the information in each low-rank update.

**Algorithm:**

1. Extract LoRA matrices (B, A) from each model where ΔW = B @ A
2. Compute reference bases via SVD: concatenate all B matrices horizontally and A matrices vertically, then compute orthonormal bases U_B and V_A
3. Project to core space: Core_i = U_B^T @ B_i @ A_i @ V_A
4. Merge in core space using weighted average
5. Reconstruct: ΔW_merged = U_B @ Core_merged @ V_A^T, then W_final = W_base + ΔW_merged (see the sketch after this list)
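
The following minimal sketch walks through steps 2–5 on toy tensors. The shapes, ranks, and equal-weight merge are illustrative assumptions for the example, not mergekit's actual code path:

```python
import torch

torch.manual_seed(0)
d_out, d_in = 32, 24

# Two hypothetical LoRA updates ΔW_i = B_i @ A_i with different ranks
B1, A1 = torch.randn(d_out, 4), torch.randn(4, d_in)
B2, A2 = torch.randn(d_out, 8), torch.randn(8, d_in)
W_base = torch.randn(d_out, d_in)

# Step 2: reference bases from the concatenated factors
U_B, _, _ = torch.linalg.svd(torch.cat([B1, B2], dim=1), full_matrices=False)
_, _, Vh_A = torch.linalg.svd(torch.cat([A1, A2], dim=0), full_matrices=False)
V_A = Vh_A.T

# Step 3: project each update into the shared core space
core_1 = U_B.T @ (B1 @ A1) @ V_A
core_2 = U_B.T @ (B2 @ A2) @ V_A

# Step 4: merge in core space (plain average here)
core_merged = (core_1 + core_2) / 2

# Step 5: reconstruct the merged delta and apply it to the base weights
delta_W = U_B @ core_merged @ V_A.T
W_final = W_base + delta_W

# Projecting and reconstructing a single update is lossless, since U_B spans
# the column space of the B factors and V_A spans the row space of the A factors
assert torch.allclose(U_B @ core_1 @ V_A.T, B1 @ A1, atol=1e-4)
```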

**Inputs:** Requires two or more models, plus a `base_model`.

**Parameters:**

- `weight` (per-model, float, default: 1.0): Relative weight for each model. The parameter is accepted, but the current implementation merges with equal weights.

**Use Cases:**

- Efficiently merging multiple LoRA adapters
- Multi-task model creation from specialized adapters
- When adapters have different ranks
- Resource-constrained environments

**Example:**

```yaml
models:
  - model: meta-llama/Llama-2-7b-hf
  - model: username/llama2-lora-math
  - model: username/llama2-lora-code

merge_method: core_space
base_model: meta-llama/Llama-2-7b-hf
dtype: bfloat16
```
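
For reference, a sketch of running the same configuration from Python rather than the CLI, based on mergekit's documented `run_merge`/`MergeOptions` API (file paths are placeholders and option names may differ across versions):

```python
import yaml
import torch

from mergekit.config import MergeConfiguration
from mergekit.merge import MergeOptions, run_merge

# Load the YAML configuration shown above (path is a placeholder)
with open("core_space.yml", "r", encoding="utf-8") as fp:
    merge_config = MergeConfiguration.model_validate(yaml.safe_load(fp))

run_merge(
    merge_config,
    out_path="./merged-core-space",
    options=MergeOptions(cuda=torch.cuda.is_available(), copy_tokenizer=True),
)
```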

**Reference:** [Accurate and Efficient Low-Rank Model Merging in Core Space](https://arxiv.org/abs/2509.17786) (Panariello et al., NeurIPS 2025)

### TIES-Merging (`ties`)

**Concept:** Builds on Task Arithmetic by sparsifying task vectors and applying a sign consensus algorithm. This helps to resolve interference when merging multiple models and retain more of their individual strengths.
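
For intuition, a minimal sketch of the trim / sign-election / disjoint-merge idea on a set of task vectors. The density, tensors, and helper name are illustrative only, not mergekit's implementation:

```python
from typing import List

import torch


def ties_sketch(task_vectors: List[torch.Tensor], density: float = 0.2) -> torch.Tensor:
    # Trim: keep only the largest-magnitude fraction of each task vector
    trimmed = []
    for tv in task_vectors:
        k = max(1, int(density * tv.numel()))
        threshold = tv.abs().flatten().kthvalue(tv.numel() - k + 1).values
        trimmed.append(torch.where(tv.abs() >= threshold, tv, torch.zeros_like(tv)))
    stacked = torch.stack(trimmed)

    # Elect: per-parameter sign carrying the larger total mass
    elected_sign = torch.sign(stacked.sum(dim=0))

    # Disjoint merge: average only the entries that agree with the elected sign
    agree = (torch.sign(stacked) == elected_sign) & (stacked != 0)
    kept = torch.where(agree, stacked, torch.zeros_like(stacked))
    return kept.sum(dim=0) / agree.sum(dim=0).clamp(min=1)


merged_delta = ties_sketch([torch.randn(8, 8) for _ in range(3)])
```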
11 changes: 11 additions & 0 deletions examples/core_space.yml
@@ -0,0 +1,11 @@
models:
  - model: gpt2
    parameters:
      weight: 0.5
  - model: gpt2
    parameters:
      weight: 1.0

merge_method: core_space
base_model: gpt2
dtype: float32
284 changes: 284 additions & 0 deletions mergekit/merge_methods/core_space.py
@@ -0,0 +1,284 @@
"""
Core Space Merging Method for mergekit
Based on "Accurate and Efficient Low-Rank Model Merging in Core Space"
(Panariello et al., NeurIPS 2025)
"""

import logging
from typing import Any, Dict, List, Optional

import torch
from typing_extensions import override

from mergekit.architecture import WeightInfo
from mergekit.common import ModelReference
from mergekit.graph import Task
from mergekit.merge_methods.base import (
    ConfigParameterDef,
    MergeMethod,
    MergeTensorInput,
)

log = logging.getLogger(__name__)


class CoreSpaceTask(Task[torch.Tensor]):
    """Task for performing core space merge on a single tensor."""

    gather_tensors: MergeTensorInput
    base_model: ModelReference
    weight_info: WeightInfo
    default_weight: float

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        """
        Execute core space merge for a single tensor.

        Args:
            tensors: Dictionary mapping model references to their tensors

        Returns:
            Merged tensor
        """
        if len(tensors) == 1:
            return list(tensors.values())[0]

        # Get base model tensor
        base_tensor = tensors.get(self.base_model)
        if base_tensor is None:
            log.warning("Base model not found, using first model as base")
            self.base_model = list(tensors.keys())[0]
            base_tensor = tensors[self.base_model]

        # Check if this is a LoRA-adapted weight
        # LoRA weights typically have "lora_A" or "lora_B" in their names
        is_lora = self._is_lora_weight(self.weight_info.name)

        if not is_lora:
            # For non-LoRA weights, fall back to weighted average
            log.debug(
                f"Using weighted average for non-LoRA weight: {self.weight_info.name}"
            )
            return self._weighted_average(tensors, base_tensor)

        # Perform core space merge for LoRA weights
        try:
            return self._core_space_merge(tensors, base_tensor)
        except Exception as e:
            log.warning(f"Core space merge failed for {self.weight_info.name}: {e}")
            log.warning("Falling back to weighted average")
            return self._weighted_average(tensors, base_tensor)

    def _is_lora_weight(self, weight_name: str) -> bool:
        """Check if a weight is LoRA-adapted."""
        lora_indicators = ["lora_A", "lora_B", "lora_", "adapter"]
        return any(indicator in weight_name for indicator in lora_indicators)

    def _extract_lora_matrices(
        self, tensors: Dict[ModelReference, torch.Tensor], base_tensor: torch.Tensor
    ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Extract LoRA A and B matrices from tensors.

        For actual LoRA adapters, we need to separate A and B matrices.
        For full fine-tuned models, we compute task vectors and approximate
        them as low-rank using SVD.
        """
        lora_As = []
        lora_Bs = []

        for model_ref, tensor in tensors.items():
            if model_ref == self.base_model:
                continue

            # Compute task vector (delta from base)
            delta = tensor - base_tensor

            # Check if this is already a LoRA matrix
            if "lora_A" in self.weight_info.name:
                # This is already the A matrix
                lora_As.append(delta)
                # We'll need to match with corresponding B matrix
                # For now, create identity-like B
                lora_Bs.append(torch.eye(delta.shape[0], device=delta.device))
            elif "lora_B" in self.weight_info.name:
                # This is already the B matrix
                lora_Bs.append(delta)
                # Create identity-like A
                lora_As.append(torch.eye(delta.shape[1], device=delta.device))
            else:
                # Full weight - approximate as low-rank via SVD
                # ΔW ≈ B @ A where rank is chosen automatically
                rank = min(16, min(delta.shape) // 4)  # Adaptive rank
                U, S, Vt = torch.linalg.svd(delta, full_matrices=False)

                # Keep top-rank components
                A = torch.diag(S[:rank]) @ Vt[:rank, :]
                B = U[:, :rank]

                lora_As.append(A)
                lora_Bs.append(B)

        return lora_As, lora_Bs

    def _core_space_merge(
        self, tensors: Dict[ModelReference, torch.Tensor], base_tensor: torch.Tensor
    ) -> torch.Tensor:
        """
        Perform core space merge.

        Steps:
        1. Extract LoRA A and B matrices
        2. Compute reference bases via SVD
        3. Project to core space
        4. Merge in core space
        5. Reconstruct to full space
        """
        # Extract LoRA matrices
        lora_As, lora_Bs = self._extract_lora_matrices(tensors, base_tensor)

        if len(lora_As) == 0:
            return base_tensor

        # Compute reference bases
        U_B, V_A = self._compute_reference_bases(lora_Bs, lora_As)

        # Determine common rank for projection
        # After concatenation, U_B and V_A may have different second dimensions
        common_rank = min(U_B.shape[1], V_A.shape[1])
        U_B_trunc = U_B[:, :common_rank]
        V_A_trunc = V_A[:, :common_rank]

        # Project each LoRA to core space
        core_reprs = []
        for A, B in zip(lora_As, lora_Bs):
            core_repr = U_B_trunc.T @ B @ A @ V_A_trunc
            core_reprs.append(core_repr)

        # Merge in core space; the current implementation averages with equal
        # weights (per-model weights are not yet applied here)
        num_models = len(core_reprs)
        core_merged = sum(core_reprs) / num_models

        # Reconstruct to full space
        delta_W = U_B_trunc @ core_merged @ V_A_trunc.T

        # Add to base model
        merged = base_tensor + delta_W

        return merged

    def _compute_reference_bases(
        self, B_matrices: List[torch.Tensor], A_matrices: List[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute reference bases U_B and V_A using SVD."""
        # Concatenate in the subspace dimension (not stacking!)
        # B matrices: (d_out, rank) each -> concatenate horizontally
        B_concat = torch.cat(B_matrices, dim=1)  # (d_out, num_models*rank)

        # A matrices: (rank, d_in) each -> concatenate vertically
        A_concat = torch.cat(A_matrices, dim=0)  # (num_models*rank, d_in)

        # Compute SVD
        U_B, _, _ = torch.linalg.svd(B_concat, full_matrices=False)
        _, _, V_A_T = torch.linalg.svd(A_concat, full_matrices=False)
        V_A = V_A_T.T

        return U_B, V_A

    def _weighted_average(
        self, tensors: Dict[ModelReference, torch.Tensor], base_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Fall back to simple weighted average."""
        # For now, use equal weights (simple average)
        result = torch.zeros_like(base_tensor)

        for model_ref, tensor in tensors.items():
            result += tensor

        return result / len(tensors) if len(tensors) > 0 else base_tensor

    def group_label(self) -> Optional[str]:
        return self.gather_tensors.group_label()


class CoreSpaceMerge(MergeMethod):
    """
    Core Space merging method for LoRA adapters.

    This method merges LoRA-adapted models by:
    1. Projecting them into a shared core space using SVD-based reference bases
    2. Merging in the compact core space
    3. Reconstructing back to full parameter space

    Benefits:
    - Efficient: Operates in compact core space
    - Aligned: SVD-based alignment of LoRA subspaces
    - Information-preserving: No loss of information in projection
    - Flexible: Supports heterogeneous ranks
    """

    def name(self) -> str:
        return "core_space"

    @override
    def pretty_name(self) -> Optional[str]:
        return "Core Space Merge"

    @override
    def reference_url(self) -> Optional[str]:
        return "https://github.com/apanariello4/core-space-merging"

    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="weight", required=False, default_value=1.0),
        ]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: MergeTensorInput,
        base_model: Optional[ModelReference],
        parameters: Dict[str, Any],
        **kwargs,
    ) -> Task:
        """
        Create a task for core space merging.

        Args:
            output_weight: Information about the output weight
            tensors: Input tensors from different models
            base_model: Base model reference
            parameters: Merge parameters (weights, etc.)
            **kwargs: Additional arguments

        Returns:
            CoreSpaceTask to execute the merge
        """
        # Get weight parameter - handle ImmutableMap
        weight_param = parameters["weight"] if "weight" in parameters else 1.0

        # Convert to float for hashability
        default_weight = (
            float(weight_param) if not isinstance(weight_param, dict) else 1.0
        )

        return CoreSpaceTask(
            gather_tensors=tensors,
            base_model=base_model,
            weight_info=output_weight,
            default_weight=default_weight,
        )


# For registration in mergekit's method registry
__all__ = ["CoreSpaceMerge"]
2 changes: 2 additions & 0 deletions mergekit/merge_methods/registry.py
@@ -5,6 +5,7 @@

from mergekit.merge_methods.arcee_fusion import ArceeFusionMerge
from mergekit.merge_methods.base import MergeMethod
from mergekit.merge_methods.core_space import CoreSpaceMerge
from mergekit.merge_methods.generalized_task_arithmetic import (
    ConsensusMethod,
    GeneralizedTaskArithmeticMerge,
@@ -25,6 +26,7 @@
    ModelStockMerge(),
    ArceeFusionMerge(),
    KarcherMerge(),
    CoreSpaceMerge(),
    # generalized task arithmetic methods
    GeneralizedTaskArithmeticMerge(
        consensus_method=None,