Commit d67a913

Added LR scheduler functionality to Tinker Trainer.
1 parent ae965d8 commit d67a913

File tree

tests/trainer/trainer_test.py
trinity/common/config.py
trinity/trainer/tinker_trainer.py

3 files changed, +97 -10 lines

tests/trainer/trainer_test.py

Lines changed: 38 additions & 1 deletion

```diff
@@ -2,6 +2,7 @@

 import asyncio
 import json
+import math
 import multiprocessing
 import os
 import shutil
@@ -48,6 +49,8 @@
 from trinity.common.models.utils import get_checkpoint_dir_with_step_num
 from trinity.explorer.proxy.client import TrinityClient
 from trinity.manager.state_manager import StateManager
+from trinity.manager.synchronizer import Synchronizer
+from trinity.trainer.tinker_trainer import TinkerTrainerWrapper


 class BaseTrainerCase(RayUnittestBase):
@@ -1448,7 +1451,7 @@ def test_trainer(self):
         self.config.buffer.total_epochs = 1
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
         self.config.model.tinker.enable = True
-        self.config.model.tinker.base_model = "Qwen/Qwen3-4B-Instruct-2507"
+        self.config.model.model_path = "Qwen/Qwen3-4B-Instruct-2507"
         self.config.check_and_update()
         both(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
@@ -1464,6 +1467,40 @@ def test_trainer(self):
         self.assertGreater(len(response_metrics), 0)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)

+    def test_trainer_class(self):
+        total_steps = 100
+        lr_warmup_steps = 10
+        self.config.algorithm.algorithm_type = "grpo"
+        self.config.model.tinker.enable = True
+        self.config.model.model_path = "Qwen/Qwen3-4B-Instruct-2507"
+        self.config.trainer.total_steps = total_steps
+        self.config.algorithm.optimizer.lr_warmup_steps = lr_warmup_steps
+        self.config.algorithm.optimizer.lr_scheduler_type = "cosine"
+        self.config.check_and_update()
+        lr = self.config.algorithm.optimizer.lr
+
+        @ray.remote
+        class FakeExplorer:
+            def __init__(self, config: Config):
+                self.config = config
+                self.synchronizer = Synchronizer.get_actor(config)
+
+        fake_explorer = FakeExplorer.remote(self.config)
+        ray.get(fake_explorer.__ray_ready__.remote())
+
+        tinker_trainer = TinkerTrainerWrapper(self.config)
+        tinker_trainer._train_step_num = 5
+        self.assertEqual(tinker_trainer.current_learning_rate, lr * 0.5)
+        tinker_trainer._train_step_num = 50
+        self.assertEqual(
+            tinker_trainer.current_learning_rate,
+            lr
+            * (
+                0.5
+                * (1 + math.cos((50 - lr_warmup_steps) / (total_steps - lr_warmup_steps) * math.pi))
+            ),
+        )
+
     def tearDown(self):
         # remove dir only when the test passed
         shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
```
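
The two assertions in the new test_trainer_class pin down the scheduler's behavior at one warmup step and one decay step. As a quick sanity check of the expected numbers (assuming the default min_lr_ratio is 0.0, which the warmup assertion implies), the arithmetic works out as follows:

```python
import math

lr_warmup_steps, total_steps = 10, 100

# Step 5 is still in warmup: factor = 5 / 10 = 0.5, so the expected LR is lr * 0.5.
warmup_factor = 5 / lr_warmup_steps

# Step 50 is in the decay phase: progress = (50 - 10) / (100 - 10) = 4/9, and the
# cosine factor is 0.5 * (1 + cos(pi * 4/9)) ~= 0.587, so the expected LR is ~0.587 * lr.
progress = (50 - lr_warmup_steps) / (total_steps - lr_warmup_steps)
cosine_factor = 0.5 * (1 + math.cos(math.pi * progress))

print(warmup_factor, round(cosine_factor, 3))  # 0.5 0.587
```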

trinity/common/config.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -100,9 +100,9 @@ class OptimizerConfig:
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
     weight_decay: float = 0.01
     clip_grad: float = 1.0
-    lr_warmup_init: float = 0.0
-    lr_decay_steps: Optional[int] = None
-    lr_decay_style: str = "constant"  # duplicated with lr_scheduler_type in veRL
+    lr_warmup_init: float = 0.0  # used in megatron
+    lr_decay_steps: Optional[int] = None  # used in megatron
+    lr_decay_style: str = "constant"  # used in megatron, duplicated with lr_scheduler_type in veRL
     min_lr: float = 0.0
```
trinity/trainer/tinker_trainer.py

Lines changed: 56 additions & 6 deletions

```diff
@@ -1,4 +1,6 @@
+import math
 import os
+import sys
 from typing import Dict, List

 import ray
@@ -36,7 +38,7 @@ def __init__(self, config: Config):

     def _init_algorithm(self):
         self.algorithm = ALGORITHM_TYPE.get(self.config.algorithm.algorithm_type)
-        algorithm_config = self.config.algorithm
+        self.algorithm_config = algorithm_config = self.config.algorithm
         if self.algorithm.compute_advantage_in_trainer:
             self.advantage_fn = ADVANTAGE_FN.get(algorithm_config.advantage_fn)(
                 **algorithm_config.advantage_fn_args
@@ -63,12 +65,60 @@ def _init_algorithm(self):
             and (self.loss_agg_mode == "token-mean")
         )

-        self.adam_params = types.AdamParams(
-            learning_rate=algorithm_config.optimizer.lr,
-            beta1=algorithm_config.optimizer.betas[0],
-            beta2=algorithm_config.optimizer.betas[1],
+        self.lr_scheduler_type = algorithm_config.optimizer.lr_scheduler_type
+        self.total_steps = self.config.trainer.total_steps or sys.maxsize
+        self.num_warmup_steps = algorithm_config.optimizer.lr_warmup_steps
+        if self.num_warmup_steps < 0:
+            self.num_warmup_steps = int(
+                algorithm_config.optimizer.lr_warmup_steps_ratio * self.total_steps
+            )
+        self.min_lr_ratio = algorithm_config.optimizer.min_lr_ratio
+        assert 0.0 <= self.min_lr_ratio <= 1.0
+        self.logger.info(
+            f"Total steps: {self.total_steps}, num_warmup_steps: {self.num_warmup_steps}"
+        )
+
+        if self.lr_scheduler_type not in {"constant", "cosine"}:
+            raise NotImplementedError(
+                f"LR scheduler type {self.lr_scheduler_type} is not supported"
+            )
+
+    @property
+    def _current_lr_factor(self):
+        train_step_num = self._train_step_num
+        # warmup
+        if train_step_num < self.num_warmup_steps:
+            factor = float(train_step_num) / float(max(1.0, self.num_warmup_steps))
+            factor = self.min_lr_ratio + (1.0 - self.min_lr_ratio) * factor
+            return factor
+
+        # decay
+        if train_step_num >= self.total_steps:
+            progress = 1.0
+        else:
+            progress = float(train_step_num - self.num_warmup_steps) / float(
+                max(1.0, self.total_steps - self.num_warmup_steps)
+            )
+        if self.lr_scheduler_type == "constant":
+            factor = 1.0
+        elif self.lr_scheduler_type == "cosine":
+            num_cycles = 0.5  # TODO: may add to config
+            factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+            factor = self.min_lr_ratio + (1.0 - self.min_lr_ratio) * factor
+        return max(self.min_lr_ratio, factor)
+
+    @property
+    def current_learning_rate(self):
+        return self._current_lr_factor * self.algorithm_config.optimizer.lr
+
+    @property
+    def adam_params(self):
+        return types.AdamParams(
+            learning_rate=self.current_learning_rate,
+            beta1=self.algorithm_config.optimizer.betas[0],
+            beta2=self.algorithm_config.optimizer.betas[1],
             # eps is currently not in config
-            weight_decay=algorithm_config.optimizer.weight_decay,
+            weight_decay=self.algorithm_config.optimizer.weight_decay,
             grad_clip_norm=self.config.trainer.grad_clip,
         )
```
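
Taken together, the commit turns the fixed AdamParams into a property that recomputes the learning rate every step: _current_lr_factor applies a linear warmup from min_lr_ratio to 1.0, then either holds the factor constant or follows a half-cycle cosine decay back down to min_lr_ratio, and current_learning_rate scales the configured lr by that factor. A minimal standalone sketch of the same schedule, where the free function lr_factor is hypothetical and num_cycles is fixed at 0.5 as in the diff:

```python
import math


def lr_factor(step: int, total_steps: int, warmup_steps: int,
              scheduler_type: str = "cosine", min_lr_ratio: float = 0.0) -> float:
    """Illustrative restatement of TinkerTrainerWrapper._current_lr_factor."""
    # Linear warmup from min_lr_ratio up to 1.0 over the first warmup_steps steps.
    if step < warmup_steps:
        frac = step / max(1.0, warmup_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * frac

    # Decay phase: progress runs from 0 at the end of warmup to 1 at total_steps.
    progress = min(1.0, (step - warmup_steps) / max(1.0, total_steps - warmup_steps))
    if scheduler_type == "constant":
        factor = 1.0
    else:  # "cosine": half a cycle (num_cycles = 0.5), so the factor falls from 1 to min_lr_ratio
        factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * 0.5 * progress))
        factor = min_lr_ratio + (1.0 - min_lr_ratio) * factor
    return max(min_lr_ratio, factor)


# With the settings used in the unit test (total_steps=100, warmup_steps=10):
print(lr_factor(10, 100, 10))   # 1.0   -> warmup just finished, full LR
print(lr_factor(55, 100, 10))   # ~0.5  -> halfway through the cosine decay
print(lr_factor(100, 100, 10))  # ~0.0  -> LR has decayed to min_lr_ratio at total_steps
```

One consequence of `self.total_steps = self.config.trainer.total_steps or sys.maxsize` is that when trainer.total_steps is unset, the decay progress stays effectively at zero, so after warmup the cosine schedule behaves like a constant learning rate.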
