Skip to content
28 changes: 28 additions & 0 deletions mergekit/merge_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,55 @@ def get(method: str) -> MergeMethod:
sparsification_method=None,
default_normalize=False,
default_rescale=False,
default_swapping=False,
)
elif method == "ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.magnitude,
default_normalize=True,
default_rescale=False,
default_swapping=False,
)
elif method == "dare_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.random,
default_normalize=False,
default_rescale=True,
default_swapping=False,
)
elif method == "dare_linear":
return GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=SparsificationMethod.random,
default_normalize=False,
default_rescale=True,
default_swapping=False,
)
elif method == "task_swapping":
return GeneralizedTaskArithmeticMerge(
consensus_method=None,
sparsification_method=None,
default_normalize=False,
default_rescale=False,
default_swapping=True,
)
elif method == "task_swapping_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.magnitude,
default_normalize=True,
default_rescale=False,
default_swapping=True,
)
elif method == "task_swapping_dare_ties":
return GeneralizedTaskArithmeticMerge(
consensus_method=ConsensusMethod.sum,
sparsification_method=SparsificationMethod.rescaled_random,
default_normalize=False,
default_rescale=True,
default_swapping=True,
)
elif method == "breadcrumbs":
return GeneralizedTaskArithmeticMerge(
Expand Down
72 changes: 71 additions & 1 deletion mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True):
sparsification_method: Optional[SparsificationMethod]
default_normalize: bool
default_rescale: bool
default_swapping: bool

def parameters(self) -> List[ConfigParameterDef]:
return [
Expand All @@ -52,12 +53,19 @@ def parameters(self) -> List[ConfigParameterDef]:
ConfigParameterDef(
name="rescale", required=False, default_value=self.default_rescale
),
ConfigParameterDef(
name="swapping", required=False, default_value=self.default_swapping
),
]

def tensor_parameters(self) -> List[ConfigParameterDef]:
res = [
ConfigParameterDef(name="weight", required=True),
ConfigParameterDef(name="density", required=False, default_value=1.0),
ConfigParameterDef(name="diagonal_offset", required=False),
ConfigParameterDef(name="invert_offset", required=False, default_value= False),
ConfigParameterDef(name="random_mask", required=False, default_value= 0.0),
ConfigParameterDef(name="random_mask_seed", required=False, default_value= None),
]
if self.sparsification_method == SparsificationMethod.magnitude_outliers:
res.append(
Expand Down Expand Up @@ -113,6 +121,7 @@ def make_task(
int8_mask=parameters["int8_mask"],
normalize=parameters["normalize"],
rescale=parameters["rescale"],
swapping=parameters["swapping"],
weight_info=output_weight,
)

Expand All @@ -126,6 +135,7 @@ class GTATask(Task[torch.Tensor]):
int8_mask: bool
normalize: bool
rescale: bool
swapping: bool

def uses_accelerator(self) -> bool:
return True
Expand All @@ -144,6 +154,7 @@ def execute(
self.base_model,
tensors,
tensor_parameters=self.tensor_parameters.data,
swapping=self.swapping,
)
if not tvs:
return base
Expand All @@ -158,6 +169,7 @@ def execute(
if "gamma" in tv_info:
kwargs["gamma"] = tv_info["gamma"]


if "epsilon" in tv_info:
kwargs["epsilon"] = tv_info["epsilon"]

Expand Down Expand Up @@ -226,15 +238,68 @@ def group_label(self) -> Optional[str]:
return self.tensors.group_label()


def swapping_method(base, x, parameters):
    """Mix elements of ``x`` and ``base`` according to a boolean mask.

    Two mutually exclusive modes are selected through ``parameters``:

    * Random mask (``random_mask`` != 0): each element is taken from ``x``
      when a uniform sample in [0, 1) is <= ``random_mask`` and from
      ``base`` otherwise. ``random_mask_seed`` makes the mask reproducible;
      ``None`` draws a fresh non-deterministic seed. ``invert_offset`` is
      not consulted in this mode.
    * Diagonal mask (``random_mask`` == 0): for 2-D tensors, element
      ``(i, j)`` is taken from ``x`` when ``(i + j) % diagonal_offset == 0``;
      for every other rank the index along the first dimension is used
      (``i % diagonal_offset == 0``). ``invert_offset`` swaps the roles of
      ``x`` and ``base``.

    Args:
        base: Reference tensor. Assumed to have the same shape and device
            as ``x`` (``torch.where`` is applied to both directly).
        x: Tensor whose elements are selectively kept.
        parameters: Mapping exposing ``.get``. Keys read here:
            ``diagonal_offset`` (integer >= 2, required only in diagonal
            mode), ``invert_offset`` (missing/``None`` means ``False``),
            ``random_mask`` (missing means ``0.0``; otherwise strictly
            between 0 and 1), ``random_mask_seed`` (``None`` or int-like).

    Returns:
        The mixed tensor, cast back to ``base``'s original dtype.

    Raises:
        ValueError: If the parameter needed by the selected mode is out of
            range.
    """

    def diagonal_mask(shape, n, device):
        # Build the mask directly on the target device; passing only
        # ``device.type`` would drop the device index (e.g. cuda:1 -> cuda:0).
        if len(shape) == 2:
            row_idx = torch.arange(shape[0], device=device).view(-1, 1)
            col_idx = torch.arange(shape[1], device=device).view(1, -1)
            return (row_idx + col_idx) % n == 0
        first_idx = torch.arange(shape[0], device=device)
        # Append singleton dims so the mask broadcasts along the FIRST
        # dimension; a bare (shape[0],) mask would be aligned with the last
        # dimension for tensors with more than two dims.
        return (first_idx % n == 0).view(-1, *([1] * (len(shape) - 1)))

    def random_keep_mask(shape, percent, seed, device):
        # A private generator leaves the global RNG untouched. torch.seed()
        # does NOT read the current seed -- it reseeds the global RNG.
        generator = torch.Generator()
        if seed is None:
            generator.seed()
        else:
            generator.manual_seed(seed)
        samples = torch.rand(shape, generator=generator)
        return (samples <= percent).to(device)

    original_dtype = base.dtype
    if x.device.type == "cpu":
        # Work in float32 on CPU; the result is cast back before returning.
        x = x.to(torch.float32)
        base = base.to(torch.float32)

    random_mask = parameters.get("random_mask", 0.0)

    if random_mask != 0.0:
        if random_mask is None or not 0.0 < random_mask < 1.0:
            raise ValueError(
                "The random_mask parameter can't be empty, 0, 1, or None, "
                "it must be a number between 0 and 1."
            )
        seed = parameters.get("random_mask_seed")
        if seed is not None:
            seed = int(seed)
        keep = random_keep_mask(x.shape, random_mask, seed, x.device)
        mixed = torch.where(keep, x, base)
    else:
        diagonal_offset = parameters.get("diagonal_offset")
        if (
            diagonal_offset is None
            or diagonal_offset % 1 != 0
            or diagonal_offset < 2
        ):
            raise ValueError(
                "The diagonal_offset must be an integer greater than or "
                "equal to 2."
            )
        keep = diagonal_mask(x.shape, int(diagonal_offset), x.device)
        if parameters.get("invert_offset", False):
            mixed = torch.where(keep, base, x)
        else:
            mixed = torch.where(keep, x, base)

    return mixed.to(original_dtype)


def get_task_vectors(
weight_info: WeightInfo,
base_model: ModelReference,
tensors: ImmutableMap[ModelReference, torch.Tensor],
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
swapping: bool,
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
keys = list(tensors.keys())
base = tensors[base_model]

parameter_name = weight_info.name

res = []
Expand All @@ -243,6 +308,7 @@ def get_task_vectors(
continue

x = tensors[model].to(base.dtype)

if x.shape != base.shape:
if weight_info.is_embed:
x = x[: base.shape[0], : base.shape[1]]
Expand All @@ -253,6 +319,10 @@ def get_task_vectors(
)
continue

if swapping:
x = swapping_method(base, x, dict(tensor_parameters[model].items()))


delta = x - base
del x
del tensors[model]
Expand Down