
Commit c9c28d9 (parent: 292587d)

Integrate recommended LR and warmup/cosine scheduler into training loop

3 files changed: 19 additions, 2 deletions

config_schema.py (12 additions, 0 deletions)

@@ -64,6 +64,18 @@ class TrainingConfig(BaseModel):
     lora_rank: int = Field(
         default=16, ge=1, le=256, description="LoRA rank (adapter dimension)"
     )
+    warmup_steps: int = Field(
+        default=100, ge=0, description="Learning rate warmup steps"
+    )
+    max_steps: int = Field(
+        default=1000, ge=1, description="Total training steps across all rounds"
+    )
+    min_lr: float = Field(
+        default=1e-6, gt=0.0, description="Minimum learning rate floor"
+    )
+    use_recommended_lr: bool = Field(
+        default=False, description="Use Tinker's recommended LR formula instead of manual LR"
+    )

     @field_validator("train_file")
     @classmethod
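
Taken together, the new fields describe a warmup-then-cosine schedule: ramp the LR up over warmup_steps, then decay it toward min_lr by max_steps. The hyperparam_utils helper imported later in trainer_with_eval.py is not itself part of this diff, so the following is only a minimal sketch of what a get_lr_with_warmup consuming these fields might look like; the signature and exact decay shape are assumptions:

```python
import math

# Sketch only: the real hyperparam_utils.get_lr_with_warmup is not shown in
# this commit, so this signature and decay shape are assumptions.
def get_lr_with_warmup(
    step: int,
    base_lr: float,
    warmup_steps: int,
    max_steps: int,
    min_lr: float = 1e-6,
) -> float:
    """Linear warmup to base_lr, then cosine decay down to min_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp: reaches base_lr at step == warmup_steps.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup budget consumed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    # Cosine anneal from base_lr at progress=0 down to min_lr at progress=1.
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

Under this shape the LR equals base_lr exactly at step == warmup_steps and bottoms out at min_lr once max_steps is reached.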

eval_loop_config.json (2 additions, 1 deletion)

@@ -14,5 +14,6 @@
   "evalops_test_suite_id": "your-test-suite-id-here",
   "steps_per_round": 1,
   "batch_size": 8,
-  "max_seq_length": 2048
+  "max_seq_length": 2048,
+  "lora_rank": 16
 }
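
Exposing lora_rank in the JSON means the value now flows through the pydantic schema above, so an out-of-range rank is rejected at load time rather than surfacing mid-training. A quick illustration, assuming pydantic v2 (implied by the field_validator import in config_schema.py) and that the file's other required fields are present:

```python
import json
from pydantic import ValidationError
from config_schema import TrainingConfig  # the schema extended in this commit

with open("eval_loop_config.json") as f:
    raw = json.load(f)

try:
    config = TrainingConfig.model_validate(raw)
    print(config.lora_rank)  # 16, within the declared ge=1 / le=256 bounds
except ValidationError as err:
    # e.g. "lora_rank": 512 would trip the le=256 constraint in the schema
    print(err)
```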

trainer_with_eval.py (5 additions, 1 deletion)

@@ -54,10 +54,13 @@
     from config_schema import TrainingConfig, load_and_validate_config
     from data_loader import DataLoader
     from simple_eval import run_simple_evaluation
+    from hyperparam_utils import get_recommended_lr, get_lr_with_warmup
 except ImportError:
     TrainingConfig = None
     DataLoader = None
     run_simple_evaluation = None
+    get_recommended_lr = None
+    get_lr_with_warmup = None


 def prepare_training_data(

@@ -232,7 +235,8 @@ async def async_main(config_path: str) -> None:
         run_training_round(training_client, datums, learning_rate)

         print("Saving model checkpoint...")
-        state_uri = training_client.save_state()
+        weights_uri = training_client.save_weights_for_sampler(name=f"round_{round_idx}")
+        state_uri = weights_uri.result().path if hasattr(weights_uri, 'result') else weights_uri
         print(f"Checkpoint saved at {state_uri}")

         print("Running evaluations...")
