Skip to content

Commit a76c206

Browse files
authored
support lr_scheduler_kwargs (#1310)
1 parent 8727f2f commit a76c206

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

swift/llm/utils/argument.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -579,7 +579,7 @@ class SftArguments(ArgumentsBase):
579579
max_grad_norm: float = 0.5
580580
predict_with_generate: bool = False
581581
lr_scheduler_type: str = 'cosine'
582-
lr_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
582+
lr_scheduler_kwargs: Optional[str] = None # json
583583
warmup_ratio: float = 0.05
584584

585585
eval_steps: int = 50
@@ -732,6 +732,12 @@ def _prepare_target_modules(self, target_modules) -> List[str]:
732732
self.lora_use_all = True
733733
return target_modules
734734

735+
def handle_lr_scheduler_kwargs(self) -> None:
    """Normalize ``self.lr_scheduler_kwargs`` into a dict.

    The field is declared as an optional JSON string so it can be supplied
    as a single command-line value. After this call:

    * ``None`` becomes an empty dict.
    * A string is decoded with ``json.loads`` and must decode to a JSON
      object (a ``dict``).
    * Any other value (e.g. a dict assigned programmatically) is left
      untouched.

    Raises:
        json.JSONDecodeError: If the string is not valid JSON.
        ValueError: If the string is valid JSON but not a JSON object.
    """
    raw = self.lr_scheduler_kwargs
    if raw is None:
        self.lr_scheduler_kwargs = {}
        return
    if not isinstance(raw, str):
        # Already a non-string value; nothing to decode.
        return
    parsed = json.loads(raw)
    if not isinstance(parsed, dict):
        # Fail here with a clear message instead of letting a list/number
        # flow onward into code that expects keyword arguments.
        raise ValueError(
            'lr_scheduler_kwargs must be a JSON object, '
            f'got {type(parsed).__name__}: {raw!r}')
    self.lr_scheduler_kwargs = parsed
740+
735741
def _prepare_modules_to_save(self, modules_to_save) -> List[str]:
736742
if isinstance(modules_to_save, str):
737743
modules_to_save = [modules_to_save]
@@ -782,6 +788,7 @@ def __post_init__(self) -> None:
782788
self.set_model_type()
783789
self.check_flash_attn()
784790
self.handle_generation_config()
791+
self.handle_lr_scheduler_kwargs()
785792
self.is_multimodal = self._is_multimodal(self.model_type)
786793

787794
self.lora_use_embedding = False

0 commit comments

Comments
 (0)