"""
Hyperparameter utilities including recommended learning rate schedules.

Based on Tinker's recommended LR formula:
LR(m) = lr_base × M_LoRA × (2000/H_m)^P_m

Reference: https://tinker-docs.thinkingmachines.ai/supervised-learning/sl-hyperparams
"""

import math
from typing import Dict


MODEL_HIDDEN_SIZES: Dict[str, int] = {
    "meta-llama/Llama-3.1-8B": 4096,
    "meta-llama/Llama-3.1-8B-Instruct": 4096,
    "meta-llama/Llama-3.1-70B": 8192,
    "meta-llama/Llama-3.3-70B-Instruct": 8192,
    "meta-llama/Llama-3.2-1B": 2048,
    "meta-llama/Llama-3.2-3B": 3072,
    "Qwen/Qwen3-8B": 4096,
    "Qwen/Qwen3-8B-Base": 4096,
    "Qwen/Qwen3-30B-A3B": 3584,
    "Qwen/Qwen3-30B-A3B-Base": 3584,
    "Qwen/Qwen3-30B-A3B-Instruct-2507": 3584,
    "Qwen/Qwen3-235B-A22B-Instruct-2507": 8192,
}


def get_recommended_lr(
    model_name: str,
    lr_base: float = 5e-5,
    lora_multiplier: float = 10.0,
) -> float:
    """
    Get recommended learning rate for a model using Tinker's formula.

    Formula: LR(m) = lr_base × M_LoRA × (2000/H_m)^P_m
    where:
        - lr_base: Base learning rate (default 5e-5)
        - M_LoRA: LoRA multiplier (default 10)
        - H_m: Hidden size of model m
        - P_m: Model-specific exponent (0.0775 for Qwen, 0.781 for Llama)

    Args:
        model_name: Full model name (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        lr_base: Base learning rate
        lora_multiplier: LoRA multiplier

    Returns:
        Recommended learning rate for the model.
    """
    hidden_size = MODEL_HIDDEN_SIZES.get(model_name)
    if hidden_size is None:
        # Without a known hidden size the (2000/H_m)^P_m factor cannot be
        # computed, so fall back to lr_base scaled by the LoRA multiplier.
        print(f"Warning: Unknown model {model_name}, using default LR")
        return lr_base * lora_multiplier

    # Model-family-specific exponent P_m.
    if "llama" in model_name.lower():
        exponent = 0.781
    elif "qwen" in model_name.lower():
        exponent = 0.0775
    else:
        exponent = 0.4  # Rough middle ground for unrecognized model families.

    lr = lr_base * lora_multiplier * math.pow(2000 / hidden_size, exponent)
    return lr


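# Worked example of the formula above (illustrative arithmetic only, not a test):
# "meta-llama/Llama-3.1-8B" has H_m = 4096 and the Llama exponent 0.781, so with
# the defaults lr_base = 5e-5 and lora_multiplier = 10:
#     LR ≈ 5e-5 * 10 * (2000 / 4096) ** 0.781 ≈ 2.9e-4
# "Qwen/Qwen3-8B" (H_m = 4096, exponent 0.0775) works out to roughly 4.7e-4.

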
def get_lr_with_warmup(
    step: int,
    base_lr: float,
    warmup_steps: int = 100,
    max_steps: int = 1000,
    min_lr: float = 1e-6,
) -> float:
    """
    Compute learning rate with linear warmup and cosine decay.

    Args:
        step: Current training step (0-indexed).
        base_lr: Peak learning rate after warmup.
        warmup_steps: Number of warmup steps.
        max_steps: Total training steps.
        min_lr: Minimum learning rate floor.

    Returns:
        Learning rate for the current step.
    """
    if step < warmup_steps:
        # Linear warmup from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps

    # Fraction of the post-warmup schedule completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(1.0, progress)

    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    lr = min_lr + (base_lr - min_lr) * cosine_decay

    return max(lr, min_lr)
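

# Minimal usage sketch (runs only when this module is executed directly):
# prints the recommended peak LR for one model, then samples the
# warmup + cosine schedule at a few steps to sanity-check its shape.
if __name__ == "__main__":
    peak_lr = get_recommended_lr("meta-llama/Llama-3.1-8B-Instruct")
    print(f"Recommended peak LR: {peak_lr:.2e}")

    for step in (0, 50, 100, 500, 999):
        lr = get_lr_with_warmup(
            step, base_lr=peak_lr, warmup_steps=100, max_steps=1000
        )
        print(f"step {step:4d}: lr = {lr:.2e}")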