@@ -579,7 +579,7 @@ class SftArguments(ArgumentsBase):
     max_grad_norm: float = 0.5
     predict_with_generate: bool = False
     lr_scheduler_type: str = 'cosine'
-    lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
+    lr_scheduler_kwargs: Optional[str] = None  # json
     warmup_ratio: float = 0.05

     eval_steps: int = 50
@@ -732,6 +732,12 @@ def _prepare_target_modules(self, target_modules) -> List[str]:
         self.lora_use_all = True
         return target_modules

+    def handle_lr_scheduler_kwargs(self):
+        if self.lr_scheduler_kwargs is None:
+            self.lr_scheduler_kwargs = {}
+        elif isinstance(self.lr_scheduler_kwargs, str):
+            self.lr_scheduler_kwargs = json.loads(self.lr_scheduler_kwargs)
+
     def _prepare_modules_to_save(self, modules_to_save) -> List[str]:
         if isinstance(modules_to_save, str):
             modules_to_save = [modules_to_save]
@@ -782,6 +788,7 @@ def __post_init__(self) -> None:
         self.set_model_type()
         self.check_flash_attn()
         self.handle_generation_config()
+        self.handle_lr_scheduler_kwargs()
         self.is_multimodal = self._is_multimodal(self.model_type)

         self.lora_use_embedding = False
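
Note: with this change, `lr_scheduler_kwargs` is accepted as a JSON string (presumably so it can be passed as a single command-line value) and converted to a dict in `__post_init__` before any downstream use. Below is a minimal, self-contained sketch of that behavior; the `_Args` dataclass and the `num_cycles` key are illustrative only and are not part of this PR.

```python
# Illustrative sketch only: mirrors the parsing logic added in this PR.
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class _Args:
    lr_scheduler_type: str = 'cosine'
    lr_scheduler_kwargs: Optional[str] = None  # JSON string, e.g. '{"num_cycles": 0.5}'

    def handle_lr_scheduler_kwargs(self):
        # None -> empty dict; JSON string -> parsed dict.
        if self.lr_scheduler_kwargs is None:
            self.lr_scheduler_kwargs = {}
        elif isinstance(self.lr_scheduler_kwargs, str):
            self.lr_scheduler_kwargs = json.loads(self.lr_scheduler_kwargs)


args = _Args(lr_scheduler_kwargs='{"num_cycles": 0.5}')
args.handle_lr_scheduler_kwargs()
assert args.lr_scheduler_kwargs == {'num_cycles': 0.5}
```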