@@ -64,6 +64,13 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
6464 ConfigParameterDef (name = "random_mask" , required = False , default_value = 0.0 ),
6565 ConfigParameterDef (name = "random_mask_seed" , required = False , default_value = None ),
6666 ]
67+ if self .sparsification_method == SparsificationMethod .magnitude_outliers :
68+ res .append (
69+ ConfigParameterDef (
70+ name = "gamma" ,
71+ default_value = 0.01 ,
72+ )
73+ )
6774 return res
6875
6976 def make_task (
@@ -83,15 +90,15 @@ def make_task(
8390 normalize = parameters ["normalize" ],
8491 rescale = parameters ["rescale" ],
8592 swapping = parameters ["swapping" ],
86- out_tensor_name = output_weight . name ,
93+ weight_info = output_weight ,
8794 )
8895
8996
9097class GTATask (Task [torch .Tensor ]):
9198 method : GeneralizedTaskArithmeticMerge
9299 tensors : GatherTensors
93100 base_model : ModelReference
94- out_tensor_name : str
101+ weight_info : WeightInfo
95102 tensor_parameters : ImmutableMap [ModelReference , Any ]
96103 int8_mask : bool
97104 normalize : bool
@@ -111,7 +118,7 @@ def execute(
111118 ) -> torch .Tensor :
112119 # collect task vectors
113120 tvs , base = get_task_vectors (
114- self .out_tensor_name ,
121+ self .weight_info ,
115122 self .base_model ,
116123 tensors ,
117124 tensor_parameters = self .tensor_parameters .data ,
@@ -123,11 +130,15 @@ def execute(
123130 # sparsify
124131 if self .method .sparsification_method :
125132 for tv_info in tvs :
133+ kwargs = {}
134+ if "gamma" in tv_info :
135+ kwargs ["gamma" ] = tv_info ["gamma" ]
126136 tv_info ["delta" ] = sparsify (
127137 tv_info ["delta" ],
128138 density = tv_info ["density" ],
129139 method = self .method .sparsification_method ,
130140 rescale = self .rescale ,
141+ ** kwargs ,
131142 )
132143
133144 deltas = torch .stack ([tv ["delta" ] for tv in tvs ], dim = 0 )
@@ -218,14 +229,15 @@ def rand_mask(base, x, percent, seed=None):
218229
219230
220231def get_task_vectors (
221- parameter_name : str ,
232+ weight_info : WeightInfo ,
222233 base_model : ModelReference ,
223234 tensors : ImmutableMap [ModelReference , torch .Tensor ],
224235 tensor_parameters : ImmutableMap [ModelReference , ImmutableMap [str , Any ]],
225236 swapping : bool ,
226237) -> Tuple [List [Dict [str , Any ]], torch .Tensor ]:
227238 keys = list (tensors .keys ())
228239 base = tensors [base_model ]
240+ parameter_name = weight_info .name
229241
230242 res = []
231243 for model in keys :
@@ -235,7 +247,7 @@ def get_task_vectors(
235247 x = tensors [model ].to (base .dtype )
236248
237249 if x .shape != base .shape :
238- if "lm_head" in parameter_name or "embed_tokens" in parameter_name :
250+ if weight_info . is_embed :
239251 x = x [: base .shape [0 ], : base .shape [1 ]]
240252 logging .warning (f"Using submatrix of { model } :{ parameter_name } " )
241253 else :
0 commit comments