From 328f4c6b353a0e67a5c49f28350b935a197087f2 Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Thu, 18 Sep 2025 17:31:39 +0200 Subject: [PATCH 01/14] added tsvm, iso-c, ta+sb, ties+sb --- .../generalized_task_arithmetic.py | 14 +- mergekit/merge_methods/registry.py | 36 +++ mergekit/subspace_helpers.py | 207 ++++++++++++++++++ 3 files changed, 252 insertions(+), 5 deletions(-) create mode 100644 mergekit/subspace_helpers.py diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 9220343b..7af0c38e 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -18,7 +18,7 @@ MergeTensorInput, ) from mergekit.sparsify import RescaleNorm, SparsificationMethod, sparsify - +from mergekit.subspace_helpers import iso_c, compute_and_sum_svd_mem_reduction class ConsensusMethod(str, Enum): count = "count" @@ -130,7 +130,6 @@ def execute( ) if not tvs: return base - # sparsify if self.method.sparsification_method: for tv_info in tvs: @@ -148,9 +147,8 @@ def execute( rescale_norm=self.rescale_norm, **kwargs, ) - deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) - + weights = torch.tensor( [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device ) @@ -180,7 +178,13 @@ def execute( if self.lambda_ != 1: mixed_delta *= self.lambda_ - + + if self.method.name() == "iso_c": + mixed_delta = iso_c(deltas, deltas.device) + elif self.method.name() == "tsvm": + mixed_delta = compute_and_sum_svd_mem_reduction(deltas, deltas.device) + elif self.method.name() in ["task_arithmetic_sb", "ties_sb"]: + mixed_delta = subspace_boosting(mixed_delta) return (base + mixed_delta).to(base.dtype) def group_label(self) -> Optional[str]: diff --git a/mergekit/merge_methods/registry.py b/mergekit/merge_methods/registry.py index 7b40f4a3..0379daa0 100644 --- a/mergekit/merge_methods/registry.py +++ b/mergekit/merge_methods/registry.py @@ -98,6 +98,42 @@ method_pretty_name="Linear DELLA", method_reference_url="https://arxiv.org/abs/2406.11617", ), + GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=None, + default_normalize=False, + default_rescale=False, + method_name="tsvm", + method_pretty_name="TSV-M", + method_reference_url="https://arxiv.org/abs/2412.00081", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=None, + default_normalize=False, + default_rescale=False, + method_name="iso_c", + method_pretty_name="ISO-C", + method_reference_url="https://www.arxiv.org/pdf/2502.04959", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=None, + default_normalize=False, + default_rescale=False, + method_name="task_arithmetic_sb", + method_pretty_name="Task Arithmetic with Subspace Boosting", + method_reference_url="https://arxiv.org/abs/2212.04089", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.magnitude, + default_normalize=True, + default_rescale=False, + method_name="ties_sb", + method_pretty_name="TIES with Subspace Boosting", + method_reference_url="https://arxiv.org/abs/2306.01708", + ), ] REGISTERED_MERGE_METHODS: Dict[str, MergeMethod] = { diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py new file mode 100644 index 00000000..2be2078e --- /dev/null +++ b/mergekit/subspace_helpers.py @@ -0,0 +1,207 @@ +import torch +from typing import List, Dict, Any + +def iso_c(task_vectors: List[Dict[str, Any]], device: torch.device) -> Dict[str, Any]: + print("Computing SVD...") + with torch.no_grad(): + new_vector = {} + for key in task_vectors[0]: + tvs = [task_vector[key] for task_vector in task_vectors] + new_vector[key] = sum(tvs) / len(tvs) + + if len(task_vectors[0][key].shape) == 2 and "text_projection" not in key: + new_vector[key] *= len(tvs) + U, S, V = torch.linalg.svd(new_vector[key], full_matrices=False) + S_mean = torch.ones_like(S) * S.mean() + + new_vector[key] = torch.linalg.multi_dot( + ( + U, + torch.diag(S_mean), + V, + ) + ) + + return new_vector + +############### +#### TSV Merge Orthogonalization +def compute_and_sum_svd_mem_reduction(task_vectors: List[Dict[str, Any]], device: torch.device) -> Dict[str, Any]: + """ + Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors, + reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate + the low-rank matrices. If the vector is not a 2D tensor or is "text_projection", it computes the mean of the vectors. + Computation of the SVD is performed also for the second operation. + + Args: + task_vectors (list): A list of task vector objects, where each object contains a + dictionary of vectors. + Returns: + dict: A dictionary containing the new vectors after SVD computation and merging. + """ + sv_reduction = 1 / len(task_vectors) + print("Computing SVD...") + with torch.no_grad(): + new_vector = {} + for key in task_vectors[0]: + new_vector[key] = {} + for i, task_vector in enumerate(task_vectors): + vec = task_vector[key] + + if ( + len(task_vector[key].shape) == 2 + and "text_projection" not in key + ): + u, s, v = torch.linalg.svd(vec, full_matrices=False) + + if i == 0: + print(f"Computed SVD for {key}...") + sum_u = torch.zeros_like(u, device=device) + sum_s = torch.zeros_like(s, device=device) + sum_v = torch.zeros_like(v, device=device) + reduced_index_s = int(s.shape[0] * sv_reduction) + + # select only the first reduced_index_s columns of u and place them + sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[ + :, :reduced_index_s + ] + sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[ + :reduced_index_s + ] + # select only the first reduced_index_s rows of v and place them + sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[ + :reduced_index_s, : + ] + + else: + if i == 0: + new_vector[key] = vec.clone() + else: + new_vector[key] += (vec - new_vector[key]) / (i + 1) + + if len(task_vector[key].shape) == 2 and "text_projection" not in key: + u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) + u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) + + new_vector[key] = torch.linalg.multi_dot( + ( + u_u, + v_u, + torch.diag(sum_s), + u_v, + v_v, + ) + ) + ''' + + if config.method.apply_subspace_boosting: + print("Applying subspace boosting...") + U, S, Vh = torch.linalg.svd(new_vector[key], full_matrices=False) + + total_sum = S.sum() + cumulative = torch.cumsum(S, dim=0) + thresh = config.method.svd_thresh # svd threshold for the boosting + + # only apply boosting to non-empty matrices + if total_sum.item() != 0: + k = (cumulative / total_sum >= thresh).nonzero(as_tuple=False) + cutoff_idx = k[0].item() + + S_damped = torch.clamp(S, min=S[cutoff_idx]) + + new_vector[key] = (U * S_damped.unsqueeze(0)) @ Vh + ''' + + return new_vector + +def subspace_boosting( + merged_tv_state_dict: Dict[str, Any], + reset_thresh=20, # TODO: refactor the parameter list and just use the config + svd_thresh=0.01, + attn_svd_thresh=0.10, + cumsum=True, + remove_keys=[] + ) -> Dict[str, Any]: + """ + Subspace boosting for merging task vectors. + + Parameters: + tv_flat_checks: Flattened task vectors. + ptm_check: + Pretrained model. + config: + Configuration object containing method parameters (e.g., config.method.k, config.method.use_ties). + reset_thresh: default 20 + Threshold parameter used for ties merging. defaults to 20. + svd_thresh: default 0.01 + Threshold for singular value boosting. If cumsum is True, used as a cumulative ratio threshold; + otherwise used as a fraction of the total number of singular values. Defaults to 0.01. + cumsum: + Whether to use the cumulative sum approach for thresholding the singular values. + remove_keys: + Optional list of keys to remove from the state dict conversion. + + Returns: + A merged flat vector representing the task vector after subspace boosting. + + Raises: + ValueError: If the base_method is not one of the defined options. + """ + + # Merging approach for attention weight matrices + #apply_to_attn = config.method.apply_to_attn + # apply_to_attn=False: no subspace boosting for attention weights + #if apply_to_attn not in [False, "full_attn", "per_qkv", "per_head"]: + # raise ValueError(f"Apply to attention method {apply_to_attn} not defined.") + + keys_to_eval = [ + ".self_attn.q_proj.weight", + ".self_attn.k_proj.weight", + ".self_attn.v_proj.weight", + ".self_attn.o_proj.weight", + ".mlp.gate_proj.weight", + ".mlp.up_proj.weight", + ".mlp.down_proj.weight", + ] + + start_time = time.time_ns() + + for key, param in merged_tv_state_dict.items(): + if any(i in key for i in keys_to_eval) and isinstance(param, torch.Tensor): + logging.info(f"Applying subspace boosting to {key} with shape {param.shape}") + ''' + # Process attention weights per head or qkv + if keys_to_eval[0] in key: + if apply_to_attn == "per_head": + merged_tv_state_dict[key] = _per_head_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum) + elif apply_to_attn == "per_qkv": + merged_tv_state_dict[key] = _per_qkv_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum) + + # Process full attention weights and MLP weights + if apply_to_attn == "full_attn" or (keys_to_eval[0] not in key): + ''' + U, S, Vh = torch.linalg.svd(param, full_matrices=False) + + # Clamping approach using the cumulative sum of singular values as the threshold + if cumsum: + total_sum = S.sum() + cumulative = torch.cumsum(S, dim=0) + + # thresh = config.method.attn_svd_thresh if (keys_to_eval[0] in key) else svd_thresh + thresh = svd_thresh + + k = (cumulative / total_sum >= thresh).nonzero(as_tuple=False) + cutoff_idx = k[0].item() + + S_damped = torch.clamp(S, min=S[cutoff_idx]) + else: # Clamping approach using the threshold as an index + cutoff_idx = int(thresh * S.numel()) + S_damped = torch.clamp(S, min=S[cutoff_idx]) + + merged_tv_state_dict[key] = (U * S_damped.unsqueeze(0)) @ Vh + + end_time = time.time_ns() + + logging.info(f"Subspace Boosting took {(end_time - start_time) / 1_000_000} ms.") + + return merged_tv_state_dict \ No newline at end of file From 3c0a2c645fb000fd9bd0413bcddef773a5cc342f Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 13:30:30 +0200 Subject: [PATCH 02/14] adjusted layer names for llms --- mergekit/subspace_helpers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index 2be2078e..63c24a81 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -9,7 +9,7 @@ def iso_c(task_vectors: List[Dict[str, Any]], device: torch.device) -> Dict[str, tvs = [task_vector[key] for task_vector in task_vectors] new_vector[key] = sum(tvs) / len(tvs) - if len(task_vectors[0][key].shape) == 2 and "text_projection" not in key: + if (len(task_vectors[0][key].shape) == 2 and "embed_tokens" not in key and "lm_head" not in key): new_vector[key] *= len(tvs) U, S, V = torch.linalg.svd(new_vector[key], full_matrices=False) S_mean = torch.ones_like(S) * S.mean() @@ -50,7 +50,8 @@ def compute_and_sum_svd_mem_reduction(task_vectors: List[Dict[str, Any]], device if ( len(task_vector[key].shape) == 2 - and "text_projection" not in key + and "embed_tokens" not in key + and "lm_head" not in key ): u, s, v = torch.linalg.svd(vec, full_matrices=False) @@ -79,7 +80,11 @@ def compute_and_sum_svd_mem_reduction(task_vectors: List[Dict[str, Any]], device else: new_vector[key] += (vec - new_vector[key]) / (i + 1) - if len(task_vector[key].shape) == 2 and "text_projection" not in key: + if ( + len(task_vector[key].shape) == 2 + and "embed_tokens" not in key + and "lm_head" not in key + ): u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) From 19bc91e829ece477c6a7c6cc6166c8dc38900aaf Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 13:58:24 +0200 Subject: [PATCH 03/14] change the way tensors are passed to subspace methods --- .../merge_methods/generalized_task_arithmetic.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 7af0c38e..586f2060 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -179,10 +179,18 @@ def execute( if self.lambda_ != 1: mixed_delta *= self.lambda_ + if self.lambda_ != 1: + mixed_delta *= self.lambda_ + + param_key = self.weight_info.name + subspace_input = [{param_key: tv["delta"]} for tv in tvs] + if self.method.name() == "iso_c": - mixed_delta = iso_c(deltas, deltas.device) + subspace_out = iso_c(subspace_input, deltas.device) + mixed_delta = subspace_out[param_key] elif self.method.name() == "tsvm": - mixed_delta = compute_and_sum_svd_mem_reduction(deltas, deltas.device) + subspace_out = compute_and_sum_svd_mem_reduction(subspace_input, deltas.device) + mixed_delta = subspace_out[param_key] elif self.method.name() in ["task_arithmetic_sb", "ties_sb"]: mixed_delta = subspace_boosting(mixed_delta) return (base + mixed_delta).to(base.dtype) From 1a46f20dc6f16b9ad97f5aebb00a769289d7a9f2 Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 14:03:33 +0200 Subject: [PATCH 04/14] minor fix --- mergekit/merge_methods/generalized_task_arithmetic.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 586f2060..8fbaf0ee 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -179,9 +179,6 @@ def execute( if self.lambda_ != 1: mixed_delta *= self.lambda_ - if self.lambda_ != 1: - mixed_delta *= self.lambda_ - param_key = self.weight_info.name subspace_input = [{param_key: tv["delta"]} for tv in tvs] From a0cb81218d4b72f8284f1e3e394da57689d4311a Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 14:22:16 +0200 Subject: [PATCH 05/14] import error --- mergekit/merge_methods/generalized_task_arithmetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 8fbaf0ee..fece2a23 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -18,7 +18,7 @@ MergeTensorInput, ) from mergekit.sparsify import RescaleNorm, SparsificationMethod, sparsify -from mergekit.subspace_helpers import iso_c, compute_and_sum_svd_mem_reduction +from mergekit.subspace_helpers import iso_c, compute_and_sum_svd_mem_reduction, subspace_boosting class ConsensusMethod(str, Enum): count = "count" From 9e3d11cbce2d4d396388e7287795b4c44a9c778a Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 14:24:01 +0200 Subject: [PATCH 06/14] import error --- mergekit/subspace_helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index 63c24a81..a21df60c 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -1,5 +1,7 @@ import torch -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional +import time +import logging def iso_c(task_vectors: List[Dict[str, Any]], device: torch.device) -> Dict[str, Any]: print("Computing SVD...") From b33bbead7f25ec70fe301abef2c09c68f6a55244 Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 14:33:08 +0200 Subject: [PATCH 07/14] minor fix for subspace boosting --- mergekit/merge_methods/generalized_task_arithmetic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index fece2a23..ff3153f1 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -189,7 +189,8 @@ def execute( subspace_out = compute_and_sum_svd_mem_reduction(subspace_input, deltas.device) mixed_delta = subspace_out[param_key] elif self.method.name() in ["task_arithmetic_sb", "ties_sb"]: - mixed_delta = subspace_boosting(mixed_delta) + subspace_input = {param_key: mixed_delta} + mixed_delta = subspace_boosting(subspace_input) return (base + mixed_delta).to(base.dtype) def group_label(self) -> Optional[str]: From c83357188aeaa14bb7ab811b98ac2d69e9d04dcf Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 14:39:50 +0200 Subject: [PATCH 08/14] minor fix for subspace boosting --- mergekit/subspace_helpers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index a21df60c..1bea917a 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -172,7 +172,8 @@ def subspace_boosting( ] start_time = time.time_ns() - + print('Length of merged_tv_state_dict:', len(merged_tv_state_dict.keys())) + global_key = merged_tv_state_dict.keys()[0] for key, param in merged_tv_state_dict.items(): if any(i in key for i in keys_to_eval) and isinstance(param, torch.Tensor): logging.info(f"Applying subspace boosting to {key} with shape {param.shape}") @@ -211,4 +212,4 @@ def subspace_boosting( logging.info(f"Subspace Boosting took {(end_time - start_time) / 1_000_000} ms.") - return merged_tv_state_dict \ No newline at end of file + return torch.tensor(merged_tv_state_dict[global_key]) \ No newline at end of file From c6a4e18ddbf69f44c8a94c2328656a370ede4a5e Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 14:45:03 +0200 Subject: [PATCH 09/14] updates for subspace boosting --- .../generalized_task_arithmetic.py | 3 +- mergekit/subspace_helpers.py | 70 +++++++++---------- 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index ff3153f1..c83696ff 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -189,8 +189,7 @@ def execute( subspace_out = compute_and_sum_svd_mem_reduction(subspace_input, deltas.device) mixed_delta = subspace_out[param_key] elif self.method.name() in ["task_arithmetic_sb", "ties_sb"]: - subspace_input = {param_key: mixed_delta} - mixed_delta = subspace_boosting(subspace_input) + mixed_delta = subspace_boosting(param_key, mixed_delta) return (base + mixed_delta).to(base.dtype) def group_label(self) -> Optional[str]: diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index 1bea917a..ecd0b464 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -122,7 +122,8 @@ def compute_and_sum_svd_mem_reduction(task_vectors: List[Dict[str, Any]], device return new_vector def subspace_boosting( - merged_tv_state_dict: Dict[str, Any], + merged_tv_key: str, + merged_tv: torch.Tensor, reset_thresh=20, # TODO: refactor the parameter list and just use the config svd_thresh=0.01, attn_svd_thresh=0.10, @@ -172,44 +173,41 @@ def subspace_boosting( ] start_time = time.time_ns() - print('Length of merged_tv_state_dict:', len(merged_tv_state_dict.keys())) - global_key = merged_tv_state_dict.keys()[0] - for key, param in merged_tv_state_dict.items(): - if any(i in key for i in keys_to_eval) and isinstance(param, torch.Tensor): - logging.info(f"Applying subspace boosting to {key} with shape {param.shape}") - ''' - # Process attention weights per head or qkv - if keys_to_eval[0] in key: - if apply_to_attn == "per_head": - merged_tv_state_dict[key] = _per_head_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum) - elif apply_to_attn == "per_qkv": - merged_tv_state_dict[key] = _per_qkv_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum) + if any(i in merged_tv_key for i in keys_to_eval) and isinstance(merged_tv, torch.Tensor): + logging.info(f"Applying subspace boosting to {merged_tv_key} with shape {merged_tv.shape}") + ''' + # Process attention weights per head or qkv + if keys_to_eval[0] in key: + if apply_to_attn == "per_head": + merged_tv_state_dict[key] = _per_head_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum) + elif apply_to_attn == "per_qkv": + merged_tv_state_dict[key] = _per_qkv_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum) + + # Process full attention weights and MLP weights + if apply_to_attn == "full_attn" or (keys_to_eval[0] not in key): + ''' + U, S, Vh = torch.linalg.svd(merged_tv, full_matrices=False) + + # Clamping approach using the cumulative sum of singular values as the threshold + if cumsum: + total_sum = S.sum() + cumulative = torch.cumsum(S, dim=0) - # Process full attention weights and MLP weights - if apply_to_attn == "full_attn" or (keys_to_eval[0] not in key): - ''' - U, S, Vh = torch.linalg.svd(param, full_matrices=False) - - # Clamping approach using the cumulative sum of singular values as the threshold - if cumsum: - total_sum = S.sum() - cumulative = torch.cumsum(S, dim=0) - - # thresh = config.method.attn_svd_thresh if (keys_to_eval[0] in key) else svd_thresh - thresh = svd_thresh - - k = (cumulative / total_sum >= thresh).nonzero(as_tuple=False) - cutoff_idx = k[0].item() - - S_damped = torch.clamp(S, min=S[cutoff_idx]) - else: # Clamping approach using the threshold as an index - cutoff_idx = int(thresh * S.numel()) - S_damped = torch.clamp(S, min=S[cutoff_idx]) - - merged_tv_state_dict[key] = (U * S_damped.unsqueeze(0)) @ Vh + # thresh = config.method.attn_svd_thresh if (keys_to_eval[0] in key) else svd_thresh + thresh = svd_thresh + + k = (cumulative / total_sum >= thresh).nonzero(as_tuple=False) + cutoff_idx = k[0].item() + + S_damped = torch.clamp(S, min=S[cutoff_idx]) + else: # Clamping approach using the threshold as an index + cutoff_idx = int(thresh * S.numel()) + S_damped = torch.clamp(S, min=S[cutoff_idx]) + + merged_tv = (U * S_damped.unsqueeze(0)) @ Vh end_time = time.time_ns() logging.info(f"Subspace Boosting took {(end_time - start_time) / 1_000_000} ms.") - return torch.tensor(merged_tv_state_dict[global_key]) \ No newline at end of file + return merged_tv \ No newline at end of file From 4b2531463d4afb65df646056c8ed31ab88a4be8c Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 14:49:33 +0200 Subject: [PATCH 10/14] updates for subspace boosting --- mergekit/subspace_helpers.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index ecd0b464..0c6e8e5f 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -175,36 +175,34 @@ def subspace_boosting( start_time = time.time_ns() if any(i in merged_tv_key for i in keys_to_eval) and isinstance(merged_tv, torch.Tensor): logging.info(f"Applying subspace boosting to {merged_tv_key} with shape {merged_tv.shape}") - ''' - # Process attention weights per head or qkv - if keys_to_eval[0] in key: - if apply_to_attn == "per_head": - merged_tv_state_dict[key] = _per_head_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum) - elif apply_to_attn == "per_qkv": - merged_tv_state_dict[key] = _per_qkv_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum) - # Process full attention weights and MLP weights - if apply_to_attn == "full_attn" or (keys_to_eval[0] not in key): - ''' - U, S, Vh = torch.linalg.svd(merged_tv, full_matrices=False) + # Store original dtype + original_dtype = merged_tv.dtype + + # Convert to float32 for SVD + merged_tv_fp32 = merged_tv.to(torch.float32) + + U, S, Vh = torch.linalg.svd(merged_tv_fp32, full_matrices=False) - # Clamping approach using the cumulative sum of singular values as the threshold if cumsum: total_sum = S.sum() cumulative = torch.cumsum(S, dim=0) - # thresh = config.method.attn_svd_thresh if (keys_to_eval[0] in key) else svd_thresh thresh = svd_thresh k = (cumulative / total_sum >= thresh).nonzero(as_tuple=False) cutoff_idx = k[0].item() S_damped = torch.clamp(S, min=S[cutoff_idx]) - else: # Clamping approach using the threshold as an index + else: # Clamping approach using the threshold as an index cutoff_idx = int(thresh * S.numel()) S_damped = torch.clamp(S, min=S[cutoff_idx]) + # Perform matrix multiplication in FP32 merged_tv = (U * S_damped.unsqueeze(0)) @ Vh + + # Convert back to original dtype + merged_tv = merged_tv.to(original_dtype) end_time = time.time_ns() From 1718ee4d5153e3ce80f5367c9102c947944fedd0 Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Fri, 19 Sep 2025 17:08:35 +0200 Subject: [PATCH 11/14] updated tsvm, iso_c, subspace boosting to work with mergekit --- .../generalized_task_arithmetic.py | 9 +- mergekit/subspace_helpers.py | 174 +++++++++--------- 2 files changed, 86 insertions(+), 97 deletions(-) diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index c83696ff..6db2ec17 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -180,16 +180,15 @@ def execute( mixed_delta *= self.lambda_ param_key = self.weight_info.name - subspace_input = [{param_key: tv["delta"]} for tv in tvs] + subspace_input = [tv["delta"] for tv in tvs] if self.method.name() == "iso_c": - subspace_out = iso_c(subspace_input, deltas.device) - mixed_delta = subspace_out[param_key] + mixed_delta = iso_c(subspace_input, param_key, deltas.device) elif self.method.name() == "tsvm": - subspace_out = compute_and_sum_svd_mem_reduction(subspace_input, deltas.device) - mixed_delta = subspace_out[param_key] + mixed_delta = compute_and_sum_svd_mem_reduction(subspace_input, param_key, deltas.device) elif self.method.name() in ["task_arithmetic_sb", "ties_sb"]: mixed_delta = subspace_boosting(param_key, mixed_delta) + return (base + mixed_delta).to(base.dtype) def group_label(self) -> Optional[str]: diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index 0c6e8e5f..6c1e8ff4 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -3,32 +3,34 @@ import time import logging -def iso_c(task_vectors: List[Dict[str, Any]], device: torch.device) -> Dict[str, Any]: - print("Computing SVD...") +def iso_c(task_vectors: List[torch.Tensor], tv_key: str, device: torch.device) -> Dict[str, Any]: with torch.no_grad(): - new_vector = {} - for key in task_vectors[0]: - tvs = [task_vector[key] for task_vector in task_vectors] - new_vector[key] = sum(tvs) / len(tvs) - - if (len(task_vectors[0][key].shape) == 2 and "embed_tokens" not in key and "lm_head" not in key): - new_vector[key] *= len(tvs) - U, S, V = torch.linalg.svd(new_vector[key], full_matrices=False) - S_mean = torch.ones_like(S) * S.mean() - - new_vector[key] = torch.linalg.multi_dot( - ( - U, - torch.diag(S_mean), - V, - ) + tvs = task_vectors + new_vector = sum(tvs) / len(tvs) + original_dtype = new_vector.dtype # Store original dtype + + if (len(task_vectors[0].shape) == 2 and "embed_tokens" not in tv_key and "lm_head" not in tv_key): + print(f"Computing SVD for {tv_key}... with shape {task_vectors[0].shape}") + new_vector *= len(tvs) + # Convert to float32 for SVD + vec_fp32 = new_vector.to(torch.float32) + U, S, V = torch.linalg.svd(vec_fp32, full_matrices=False) + S_mean = torch.ones_like(S) * S.mean() + + # Perform matrix multiplication in float32 and convert back to original dtype + new_vector = torch.linalg.multi_dot( + ( + U, + torch.diag(S_mean), + V, ) + ).to(original_dtype) # Convert back to original dtype return new_vector ############### #### TSV Merge Orthogonalization -def compute_and_sum_svd_mem_reduction(task_vectors: List[Dict[str, Any]], device: torch.device) -> Dict[str, Any]: +def compute_and_sum_svd_mem_reduction(task_vectors: List[torch.Tensor], tv_key: str, device: torch.device) -> Dict[str, Any]: """ Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors, reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate @@ -42,82 +44,65 @@ def compute_and_sum_svd_mem_reduction(task_vectors: List[Dict[str, Any]], device dict: A dictionary containing the new vectors after SVD computation and merging. """ sv_reduction = 1 / len(task_vectors) - print("Computing SVD...") with torch.no_grad(): new_vector = {} - for key in task_vectors[0]: - new_vector[key] = {} - for i, task_vector in enumerate(task_vectors): - vec = task_vector[key] - - if ( - len(task_vector[key].shape) == 2 - and "embed_tokens" not in key - and "lm_head" not in key - ): - u, s, v = torch.linalg.svd(vec, full_matrices=False) - - if i == 0: - print(f"Computed SVD for {key}...") - sum_u = torch.zeros_like(u, device=device) - sum_s = torch.zeros_like(s, device=device) - sum_v = torch.zeros_like(v, device=device) - reduced_index_s = int(s.shape[0] * sv_reduction) - - # select only the first reduced_index_s columns of u and place them - sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[ - :, :reduced_index_s - ] - sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[ - :reduced_index_s - ] - # select only the first reduced_index_s rows of v and place them - sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[ - :reduced_index_s, : - ] - - else: - if i == 0: - new_vector[key] = vec.clone() - else: - new_vector[key] += (vec - new_vector[key]) / (i + 1) + for i, task_vector in enumerate(task_vectors): + vec = task_vector + original_dtype = vec.dtype # Store original dtype if ( - len(task_vector[key].shape) == 2 - and "embed_tokens" not in key - and "lm_head" not in key + len(task_vector.shape) == 2 + and "embed_tokens" not in tv_key + and "lm_head" not in tv_key ): - u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) - u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) - - new_vector[key] = torch.linalg.multi_dot( - ( - u_u, - v_u, - torch.diag(sum_s), - u_v, - v_v, - ) + print(f"Computing SVD for {tv_key}... with shape {task_vector.shape}") + # Convert to float32 for SVD + vec_fp32 = vec.to(torch.float32) + u, s, v = torch.linalg.svd(vec_fp32, full_matrices=False) + + if i == 0: + sum_u = torch.zeros_like(u, device=device, dtype=torch.float32) + sum_s = torch.zeros_like(s, device=device, dtype=torch.float32) + sum_v = torch.zeros_like(v, device=device, dtype=torch.float32) + reduced_index_s = int(s.shape[0] * sv_reduction) + + # select only the first reduced_index_s columns of u and place them + sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[ + :, :reduced_index_s + ] + sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[ + :reduced_index_s + ] + # select only the first reduced_index_s rows of v and place them + sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[ + :reduced_index_s, : + ] + + else: + if i == 0: + new_vector = vec.clone() + else: + new_vector += (vec - new_vector) / (i + 1) + + if ( + len(task_vector.shape) == 2 + and "embed_tokens" not in tv_key + and "lm_head" not in tv_key + ): + # Perform final SVD operations in float32 + u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) + u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) + + # Perform matrix multiplication in float32 + new_vector = torch.linalg.multi_dot( + ( + u_u, + v_u, + torch.diag(sum_s), + u_v, + v_v, ) - ''' - - if config.method.apply_subspace_boosting: - print("Applying subspace boosting...") - U, S, Vh = torch.linalg.svd(new_vector[key], full_matrices=False) - - total_sum = S.sum() - cumulative = torch.cumsum(S, dim=0) - thresh = config.method.svd_thresh # svd threshold for the boosting - - # only apply boosting to non-empty matrices - if total_sum.item() != 0: - k = (cumulative / total_sum >= thresh).nonzero(as_tuple=False) - cutoff_idx = k[0].item() - - S_damped = torch.clamp(S, min=S[cutoff_idx]) - - new_vector[key] = (U * S_damped.unsqueeze(0)) @ Vh - ''' + ).to(original_dtype) # Convert back to original dtype return new_vector @@ -171,10 +156,15 @@ def subspace_boosting( ".mlp.up_proj.weight", ".mlp.down_proj.weight", ] - + ''' + print('merged_tv_key: ', merged_tv_key) + print('type(merged_tv_key): ', type(merged_tv_key)) + print('type(merged_tv): ', type(merged_tv)) + print('merged_tv.shape: ', merged_tv.shape) + ''' start_time = time.time_ns() if any(i in merged_tv_key for i in keys_to_eval) and isinstance(merged_tv, torch.Tensor): - logging.info(f"Applying subspace boosting to {merged_tv_key} with shape {merged_tv.shape}") + print(f"Applying subspace boosting to {merged_tv_key} with shape {merged_tv.shape}") # Store original dtype original_dtype = merged_tv.dtype @@ -206,6 +196,6 @@ def subspace_boosting( end_time = time.time_ns() - logging.info(f"Subspace Boosting took {(end_time - start_time) / 1_000_000} ms.") + # print(f"Subspace Boosting took {(end_time - start_time) / 1_000_000} ms.") return merged_tv \ No newline at end of file From 7dcd281ea60732b0157cbdf1dd3db02690940826 Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Sun, 12 Oct 2025 03:08:41 +0300 Subject: [PATCH 12/14] updated subspace boosting for handling cutoff_idx not found error --- mergekit/subspace_helpers.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index 0c6e8e5f..98f407a6 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -191,7 +191,13 @@ def subspace_boosting( thresh = svd_thresh k = (cumulative / total_sum >= thresh).nonzero(as_tuple=False) - cutoff_idx = k[0].item() + + if k.numel() == 0: + # fallback: use smallest singular value + cutoff_idx = -1 + print(f"[Warning] No valid SVD cutoff for {merged_tv_key}. Using full singular spectrum.") + else: + cutoff_idx = k[0].item() S_damped = torch.clamp(S, min=S[cutoff_idx]) else: # Clamping approach using the threshold as an index From 22b680ad69c00e364b1e8ec27738217a5cf65ca3 Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Sun, 12 Oct 2025 03:54:04 +0300 Subject: [PATCH 13/14] added gesvd driver fallback for tsvm --- mergekit/subspace_helpers.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index eeef4f06..72a5abb4 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -90,8 +90,18 @@ def compute_and_sum_svd_mem_reduction(task_vectors: List[torch.Tensor], tv_key: and "lm_head" not in tv_key ): # Perform final SVD operations in float32 - u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) - u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) + + try: + u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False) + except torch._C._LinAlgError: + print(f"[Retry with 'gesvd'] SVD failed for {tv_key}.") + u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False, driver='gesvd') + + try: + u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False) + except torch._C._LinAlgError: + print(f"[Retry with 'gesvd'] SVD failed for {tv_key}.") + u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False, driver='gesvd') # Perform matrix multiplication in float32 new_vector = torch.linalg.multi_dot( From 3f031f34082c8a944311291bff5173bd19f78aad Mon Sep 17 00:00:00 2001 From: kaganhitit11 Date: Tue, 21 Oct 2025 15:45:43 +0300 Subject: [PATCH 14/14] added svd thresh and cumsum as hyperparams for subspace boosting --- .../merge_methods/generalized_task_arithmetic.py | 8 +++++++- mergekit/subspace_helpers.py | 15 +-------------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 6db2ec17..c5216705 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -55,6 +55,8 @@ def parameters(self) -> List[ConfigParameterDef]: name="rescale", required=False, default_value=self.default_rescale ), ConfigParameterDef(name="lambda", required=False, default_value=1.0), + ConfigParameterDef(name="svd_thresh", required=False, default_value=0.01), + ConfigParameterDef(name="cumsum", required=False, default_value=True), ] def tensor_parameters(self) -> List[ConfigParameterDef]: @@ -96,6 +98,8 @@ def make_task( lambda_=parameters["lambda"], rescale_norm=RescaleNorm.l1 if parameters["rescale"] else None, weight_info=output_weight, + svd_thresh=parameters["svd_thresh"], + cumsum=parameters["cumsum"], ) @@ -109,6 +113,8 @@ class GTATask(Task[torch.Tensor]): normalize: bool lambda_: float rescale_norm: Optional[RescaleNorm] + svd_thresh: float + cumsum: bool def uses_accelerator(self) -> bool: return True @@ -187,7 +193,7 @@ def execute( elif self.method.name() == "tsvm": mixed_delta = compute_and_sum_svd_mem_reduction(subspace_input, param_key, deltas.device) elif self.method.name() in ["task_arithmetic_sb", "ties_sb"]: - mixed_delta = subspace_boosting(param_key, mixed_delta) + mixed_delta = subspace_boosting(param_key, mixed_delta, svd_thresh=self.svd_thresh, cumsum=self.cumsum) return (base + mixed_delta).to(base.dtype) diff --git a/mergekit/subspace_helpers.py b/mergekit/subspace_helpers.py index 72a5abb4..0a6f7b52 100644 --- a/mergekit/subspace_helpers.py +++ b/mergekit/subspace_helpers.py @@ -119,11 +119,8 @@ def compute_and_sum_svd_mem_reduction(task_vectors: List[torch.Tensor], tv_key: def subspace_boosting( merged_tv_key: str, merged_tv: torch.Tensor, - reset_thresh=20, # TODO: refactor the parameter list and just use the config svd_thresh=0.01, - attn_svd_thresh=0.10, cumsum=True, - remove_keys=[] ) -> Dict[str, Any]: """ Subspace boosting for merging task vectors. @@ -166,13 +163,7 @@ def subspace_boosting( ".mlp.up_proj.weight", ".mlp.down_proj.weight", ] - ''' - print('merged_tv_key: ', merged_tv_key) - print('type(merged_tv_key): ', type(merged_tv_key)) - print('type(merged_tv): ', type(merged_tv)) - print('merged_tv.shape: ', merged_tv.shape) - ''' - start_time = time.time_ns() + if any(i in merged_tv_key for i in keys_to_eval) and isinstance(merged_tv, torch.Tensor): print(f"Applying subspace boosting to {merged_tv_key} with shape {merged_tv.shape}") @@ -210,8 +201,4 @@ def subspace_boosting( # Convert back to original dtype merged_tv = merged_tv.to(original_dtype) - end_time = time.time_ns() - - # print(f"Subspace Boosting took {(end_time - start_time) / 1_000_000} ms.") - return merged_tv \ No newline at end of file