@@ -39,6 +39,7 @@ class SftArguments:

     sft_type: Literal['lora', 'full', 'longlora', 'qalora'] = 'lora'
     freeze_parameters: float = 0.  # 0 ~ 1
+    additional_trainable_parameters: List[str] = field(default_factory=list)
     tuner_backend: Literal['swift', 'peft'] = 'swift'
     template_type: str = field(
         default='AUTO',
@@ -211,6 +212,9 @@ def __post_init__(self) -> None:
             assert self.freeze_parameters == 0., (
                 'lora does not support `freeze_parameters`, please set `--sft_type full`'
             )
+            assert len(self.additional_trainable_parameters) == 0, (
+                'lora does not support `additional_trainable_parameters`, please set `--sft_type full`'
+            )
             if 'int4' in self.model_type or 'int8' in self.model_type:
                 assert self.quantization_bit == 0, 'int4 and int8 models do not need to be quantized again.'
             if self.learning_rate is None:
@@ -221,12 +225,16 @@ def __post_init__(self) -> None:
                 else:
                     self.only_save_model = True
         elif self.sft_type == 'full':
-            assert 0 <= self.freeze_parameters < 1
+            assert 0 <= self.freeze_parameters <= 1
             assert self.quantization_bit == 0, 'Full parameter fine-tuning does not support quantization.'
             assert self.dtype != 'fp16', (
                 "Fine-tuning with dtype=='fp16' can lead to NaN issues. "
                 'Please use fp32+AMP or bf16 to perform full parameter fine-tuning.'
             )
+            if isinstance(self.additional_trainable_parameters, str):
+                self.additional_trainable_parameters = [
+                    self.additional_trainable_parameters
+                ]
             if self.learning_rate is None:
                 self.learning_rate = 2e-5
             if self.only_save_model is None:
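
A minimal standalone sketch of the new behavior (not the library's code): it mirrors the checks around the new `additional_trainable_parameters` field in `__post_init__`, including the str-to-list normalization added for `--sft_type full`, in a trimmed-down stand-in class so it can run in isolation. The `transformer.wte` prefix is a hypothetical example value.

from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Args:
    # Trimmed-down stand-in for SftArguments, keeping only the fields touched here.
    sft_type: str = 'lora'
    freeze_parameters: float = 0.  # 0 ~ 1, fraction of parameters to freeze
    additional_trainable_parameters: Union[str, List[str]] = field(default_factory=list)

    def __post_init__(self) -> None:
        if self.sft_type == 'lora':
            # LoRA trains adapters only, so these full-tuning knobs are rejected.
            assert self.freeze_parameters == 0.
            assert len(self.additional_trainable_parameters) == 0
        elif self.sft_type == 'full':
            assert 0 <= self.freeze_parameters <= 1
            # A single string (e.g. from a CLI flag) is normalized to a one-element list.
            if isinstance(self.additional_trainable_parameters, str):
                self.additional_trainable_parameters = [
                    self.additional_trainable_parameters
                ]


args = Args(sft_type='full', freeze_parameters=0.5,
            additional_trainable_parameters='transformer.wte')
assert args.additional_trainable_parameters == ['transformer.wte']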