Commit 292587d

Add Tinker recommended LR formula and warmup/cosine scheduler
1 parent d7d0f8b commit 292587d

File tree

1 file changed: +98 -0 lines changed


hyperparam_utils.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
"""
Hyperparameter utilities including recommended learning rate schedules.

Based on Tinker's recommended LR formula:
    LR(m) = lr_base × M_LoRA × (2000/H_m)^P_m

Reference: https://tinker-docs.thinkingmachines.ai/supervised-learning/sl-hyperparams
"""

import math
from typing import Dict


MODEL_HIDDEN_SIZES: Dict[str, int] = {
    "meta-llama/Llama-3.1-8B": 4096,
    "meta-llama/Llama-3.1-8B-Instruct": 4096,
    "meta-llama/Llama-3.1-70B": 8192,
    "meta-llama/Llama-3.3-70B-Instruct": 8192,
    "meta-llama/Llama-3.2-1B": 2048,
    "meta-llama/Llama-3.2-3B": 3072,
    "Qwen/Qwen3-8B": 4096,
    "Qwen/Qwen3-8B-Base": 4096,
    "Qwen/Qwen3-30B-A3B": 3584,
    "Qwen/Qwen3-30B-A3B-Base": 3584,
    "Qwen/Qwen3-30B-A3B-Instruct-2507": 3584,
    "Qwen/Qwen3-235B-A22B-Instruct-2507": 8192,
}


def get_recommended_lr(
    model_name: str,
    lr_base: float = 5e-5,
    lora_multiplier: float = 10.0,
) -> float:
    """
    Get recommended learning rate for a model using Tinker's formula.

    Formula: LR(m) = lr_base × M_LoRA × (2000/H_m)^P_m
    where:
        - lr_base: Base learning rate (default 5e-5)
        - M_LoRA: LoRA multiplier (default 10)
        - H_m: Hidden size of model m
        - P_m: Model-specific exponent (0.0775 for Qwen, 0.781 for Llama)

    Args:
        model_name: Full model name (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        lr_base: Base learning rate
        lora_multiplier: LoRA multiplier

    Returns:
        Recommended learning rate for the model.
    """
    hidden_size = MODEL_HIDDEN_SIZES.get(model_name)
    if hidden_size is None:
        print(f"Warning: Unknown model {model_name}, using default LR")
        return lr_base * lora_multiplier

    if "llama" in model_name.lower():
        exponent = 0.781
    elif "qwen" in model_name.lower():
        exponent = 0.0775
    else:
        exponent = 0.4

    lr = lr_base * lora_multiplier * math.pow(2000 / hidden_size, exponent)
    return lr


def get_lr_with_warmup(
    step: int,
    base_lr: float,
    warmup_steps: int = 100,
    max_steps: int = 1000,
    min_lr: float = 1e-6,
) -> float:
    """
    Compute learning rate with linear warmup and cosine decay.

    Args:
        step: Current training step (0-indexed).
        base_lr: Peak learning rate after warmup.
        warmup_steps: Number of warmup steps.
        max_steps: Total training steps.
        min_lr: Minimum learning rate floor.

    Returns:
        Learning rate for the current step.
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps

    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(1.0, progress)

    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    lr = min_lr + (base_lr - min_lr) * cosine_decay

    return max(lr, min_lr)
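A minimal usage sketch of how the two helpers added here could be combined in a training loop. The model name, step counts, and optimizer hookup below are illustrative assumptions, not part of this commit:

# Illustrative usage sketch (assumptions only): pick a peak LR with Tinker's
# formula, then schedule it with linear warmup and cosine decay.
from hyperparam_utils import get_recommended_lr, get_lr_with_warmup

peak_lr = get_recommended_lr("meta-llama/Llama-3.1-8B-Instruct")
# = 5e-5 * 10 * (2000/4096)^0.781, roughly 2.9e-4

for step in range(1000):
    lr = get_lr_with_warmup(step, base_lr=peak_lr, warmup_steps=100, max_steps=1000)
    # ...apply `lr` to the optimizer's parameter groups for this step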
