
Commit bd25e0c

[LLM] add decay steps option for finetuning (#8251)
1 parent 98a4b84 commit bd25e0c
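
In short: llm/finetune_generation.py gains a FinetuneArguments dataclass that extends TrainingArguments with a single decay_steps field (default 0), and Trainer.create_scheduler in paddlenlp/trainer/trainer.py now passes decay_steps to get_scheduler in place of num_training_steps whenever it is set to a positive value, so the learning-rate decay can finish before the end of training.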

File tree: 2 files changed, +23 -2 lines changed


llm/finetune_generation.py

Lines changed: 19 additions & 1 deletion
@@ -14,6 +14,7 @@
 import json
 import os
 import sys
+from dataclasses import dataclass, field
 from functools import partial
 
 import paddle
@@ -49,6 +50,23 @@
 from paddlenlp.utils.log import logger
 
 
+def add_start_docstrings(*docstr):
+    def docstring_decorator(fn):
+        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
+        return fn
+
+    return docstring_decorator
+
+
+@dataclass
+@add_start_docstrings(TrainingArguments.__doc__)
+class FinetuneArguments(TrainingArguments):
+    decay_steps: int = field(
+        default=0,
+        metadata={"help": "The steps use to control the learing rate."},
+    )
+
+
 def read_local_dataset(path):
     with open(path, "r", encoding="utf-8") as fp:
         for line in fp:
@@ -57,7 +75,7 @@ def read_local_dataset(path):
 
 def main():
     # Arguments
-    parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments))
+    parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, FinetuneArguments))
     # Support format as "args.json --arg1 value1 --arg2 value2."
     # In case of conflict, command line arguments take precedence.
     if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
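
Because FinetuneArguments is just a TrainingArguments subclass, the new field surfaces like any other training argument once the dataclass is handed to the parser. The following is a minimal, illustrative sketch of parsing it in isolation, assuming PdArgumentParser maps dataclass fields to --flags and provides parse_args_into_dataclasses() in the usual HfArgumentParser style; the script name, paths and values are made up:

```python
from dataclasses import dataclass, field

from paddlenlp.trainer import PdArgumentParser, TrainingArguments


@dataclass
class FinetuneArguments(TrainingArguments):
    # Same field as in the commit: 0 means "decay over the full num_training_steps".
    decay_steps: int = field(
        default=0,
        metadata={"help": "The steps used to control the learning rate."},
    )


if __name__ == "__main__":
    # e.g. python finetune.py --output_dir ./checkpoints --max_steps 10000 --decay_steps 2000
    (training_args,) = PdArgumentParser((FinetuneArguments,)).parse_args_into_dataclasses()
    print(training_args.decay_steps)  # 2000 if passed on the command line, otherwise 0
```

In the actual script the dataclass sits alongside GenerateArgument, QuantArgument, ModelArgument and DataArgument in the same PdArgumentParser call, as shown in the diff above.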

paddlenlp/trainer/trainer.py

Lines changed: 4 additions & 1 deletion
@@ -1650,13 +1650,16 @@ def create_scheduler(self, num_training_steps: int):
         warmup = (
             self.args.warmup_steps if self.args.warmup_steps > 0 else int(self.args.warmup_ratio * num_training_steps)
         )
+        decay_steps = num_training_steps
+        if hasattr(self.args, "decay_steps") and self.args.decay_steps > 0:
+            decay_steps = self.args.decay_steps
 
         if self.lr_scheduler is None:
             self.lr_scheduler = get_scheduler(
                 self.args.lr_scheduler_type,
                 learning_rate=self.args.learning_rate,
                 num_warmup_steps=warmup,
-                num_training_steps=num_training_steps,
+                num_training_steps=decay_steps,
                 num_cycles=self.args.num_cycles,
                 lr_end=self.args.lr_end,
                 power=self.args.power,
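
The net effect is that the decay horizon of the learning-rate schedule is decoupled from the length of the run: with a positive decay_steps, a linear schedule reaches its floor after decay_steps steps and then stays there for the remaining updates, while decay_steps = 0 keeps the old behaviour of decaying over num_training_steps. The sketch below illustrates that shape with a generic linear warmup-then-decay formula; it is written for illustration only and is not PaddleNLP's get_scheduler implementation:

```python
def linear_with_warmup(step, learning_rate, warmup, decay_steps, lr_end=0.0):
    """Ramp up linearly over `warmup` steps, then decay linearly until `decay_steps`."""
    if step < warmup:
        return learning_rate * step / max(1, warmup)
    # Past the decay horizon the schedule is flat at lr_end.
    remaining = max(0, decay_steps - step)
    span = max(1, decay_steps - warmup)
    return lr_end + (learning_rate - lr_end) * remaining / span


# With num_training_steps=10_000 but decay_steps=2_000, the rate hits lr_end at step
# 2_000 and the remaining 8_000 steps run at the floor value.
for step in (0, 500, 1_000, 2_000, 10_000):
    print(step, linear_with_warmup(step, 3e-5, warmup=500, decay_steps=2_000))
```

Because create_scheduler guards the lookup with hasattr(self.args, "decay_steps"), argument classes without the new field keep decaying over num_training_steps unchanged.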
