|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -"""Trainer protocol. |
| 7 | +"""Trainer protocol for Forge. |
| 8 | +
|
| 9 | +This module defines the unified training interface that all trainer implementations |
| 10 | +must conform to. |
8 | 11 |
|
9 | | -This file defines the unified training interface compatible |
10 | | -with all supported torchforge trainers. |
11 | 12 | """ |
12 | 13 |
|
13 | 14 | from typing import Any, Protocol, runtime_checkable |
14 | 15 |
|
15 | 16 | import torch |
16 | 17 |
|
| 18 | +from forge.api.types import ( |
| 19 | + ForwardResult, |
| 20 | + OptimStepResult, |
| 21 | + TextTrainBatch, |
| 22 | + TrainerInfo, |
| 23 | + TrainerStatus, |
| 24 | + TrainResult, |
| 25 | +) |
| 26 | + |
17 | 27 |
|
18 | 28 | @runtime_checkable |
19 | 29 | class Trainer(Protocol): |
20 | | - """Protocol for all trainers in torchforge.""" |
| 30 | + """Protocol defining the standard interface for all Forge trainers.""" |
21 | 31 |
|
22 | | - async def accumulate_gradients( |
23 | | - self, microbatch: dict[str, torch.Tensor] |
24 | | - ) -> dict[str, Any]: |
25 | | - """Accumulate gradients from one microbatch. |
| 32 | + async def forward_backward(self, batch: TextTrainBatch) -> TrainResult: |
| 33 | + """Execute forward pass and backward pass for one batch of data. |
26 | 34 |
|
27 | | - Does NOT clear gradients - they accumulate on top of existing. |
28 | | - Can be called multiple times before optim_step(). |
| 35 | + Basic usage - single batch per optimizer step: |
| 36 | + >>> batch = TextTrainBatch( |
| 37 | + >>> input_ids=torch.tensor([[1, 2, 3, 4, 5]]), |
| 38 | + >>> target_ids=torch.tensor([[2, 3, 4, 5, 6]]), |
| 39 | + >>> ) |
| 40 | + >>> result = await trainer.forward_backward(batch) |
| 41 | + >>> await trainer.optim_step() # Apply gradients |
| 42 | +
|
| 43 | + To accumulate gradients over multiple batches before optimizer step: |
| 44 | + >>> await trainer.forward_backward(batch1) # Accumulates |
| 45 | + >>> await trainer.forward_backward(batch2) # Accumulates another batch |
| 46 | + >>> await trainer.optim_step() # Apply all accumulated gradients |
| 47 | +
|
| 48 | + Args: |
| 49 | + batch: TextTrainBatch containing input_ids, target_ids, and optional |
| 50 | + target_mask/target_weights. See forge.api.types.TextTrainBatch for details. |
29 | 51 |
|
30 | 52 | Returns: |
31 | | - dict with keys: |
32 | | - - loss: float |
33 | | - - metrics: dict[str, float] |
| 53 | + TrainResult containing loss and metrics |
| 54 | +
|
| 55 | + Note: |
| 56 | + The loss function is configured at trainer creation time via the |
| 57 | + `loss` parameter, not passed to this method. |
34 | 58 | """ |
35 | 59 | ... |
36 | 60 |
|
37 | | - async def optim_step(self, params: dict[str, Any] | None = None) -> dict[str, Any]: |
38 | | - """Apply optimizer step and clear gradients after. |
| 61 | + async def optim_step(self, params: dict[str, Any] | None = None) -> OptimStepResult: |
| 62 | + """Apply optimizer step using accumulated gradients, then clear gradients. |
| 63 | +
|
| 64 | + This method: |
| 65 | + 1. Applies accumulated gradients via the optimizer |
| 66 | + 2. Steps the learning rate scheduler |
| 67 | + 3. Clears all gradients (zero_grad) |
| 68 | + 4. Increments the training step counter |
| 69 | + 5. May trigger automatic checkpointing (implementation-dependent) |
| 70 | +
|
| 71 | + Gradients must have been accumulated via forward_backward() calls before |
| 72 | + calling this method. |
| 73 | +
|
| 74 | + Args: |
| 75 | + params: Optional optimizer parameters. Currently reserved for future use. |
| 76 | + Most implementations ignore this and use the optimizer config from |
| 77 | + trainer initialization. |
39 | 78 |
|
40 | 79 | Returns: |
41 | | - dict with keys: |
42 | | - - step: int |
43 | | - - learning_rate: float |
44 | | - - accumulated_microbatches: int |
| 80 | + OptimStepResult containing step number, learning rate, and accumulated batch count |
| 81 | +
|
| 82 | + Example: |
| 83 | + >>> # Accumulate over 4 batches |
| 84 | + >>> for batch in batches[:4]: |
| 85 | + >>> await trainer.forward_backward(batch) |
| 86 | + >>> result = await trainer.optim_step() |
| 87 | + >>> print(f"Step {result.step}, LR {result.learning_rate:.2e}") |
| 88 | + >>> print(f"Accumulated {result.accumulated_microbatches} batches") |
45 | 89 | """ |
46 | 90 | ... |
47 | 91 |
|
48 | 92 | async def clear_gradients(self) -> None: |
49 | | - """Clear accumulated gradients without applying.""" |
| 93 | + """Clear accumulated gradients without applying them. |
| 94 | +
|
| 95 | + Use this when you need to discard accumulated gradients without performing |
| 96 | + an optimizer step. Common scenarios: |
| 97 | + - Exception during gradient accumulation |
| 98 | + - Skipping a training step due to some condition |
| 99 | + - Recovering from OOM or other errors |
| 100 | +
|
| 101 | + This is equivalent to calling optimizer.zero_grad() and resetting internal |
| 102 | + accumulation counters. |
| 103 | +
|
| 104 | + Example - Error recovery: |
| 105 | + >>> try: |
| 106 | + >>> for batch in batches: |
| 107 | + >>> await trainer.forward_backward(batch) |
| 108 | + >>> await trainer.optim_step() |
| 109 | + >>> except torch.cuda.OutOfMemoryError: |
| 110 | + >>> await trainer.clear_gradients() # Discard partial gradients |
| 111 | + >>> # Retry with smaller batches |
| 112 | +
|
| 113 | + Example - Conditional skip: |
| 114 | + >>> await trainer.forward_backward(batch) |
| 115 | + >>> if should_skip_step(): |
| 116 | + >>> await trainer.clear_gradients() # Don't apply these gradients |
| 117 | + >>> else: |
| 118 | + >>> await trainer.optim_step() |
| 119 | + """ |
50 | 120 | ... |
51 | 121 |
|
52 | | - async def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: |
53 | | - """Run forward pass, no backward. |
| 122 | + async def forward(self, inputs: dict[str, torch.Tensor]) -> ForwardResult: |
| 123 | + """Run forward pass only, without backward pass (for evaluation/inference). |
| 124 | +
|
| 125 | + This method executes the model's forward pass without computing gradients. |
| 126 | + Useful for: |
| 127 | + - Evaluation on validation/test data |
| 128 | + - Getting model predictions/logits |
| 129 | + - Debugging model outputs |
| 130 | +
|
| 131 | + Args: |
| 132 | + inputs: Dictionary containing model inputs. Typically includes: |
| 133 | + - input_ids: torch.Tensor [batch_size, seq_len] |
| 134 | + Other keys depend on the model architecture. |
54 | 135 |
|
55 | 136 | Returns: |
56 | | - dict with key: |
57 | | - - logits: torch.Tensor |
| 137 | + ForwardResult containing model logits |
| 138 | +
|
| 139 | + Note: |
| 140 | + This runs in torch.no_grad() context - no gradients are computed. |
| 141 | +
|
| 142 | + Example: |
| 143 | + >>> eval_batch = {"input_ids": torch.tensor([[1, 2, 3, 4]])} |
| 144 | + >>> output = await trainer.forward(eval_batch) |
| 145 | + >>> logits = output.logits # [1, 4, vocab_size] |
| 146 | + >>> predictions = logits.argmax(dim=-1) # [1, 4] |
58 | 147 | """ |
59 | 148 | ... |
60 | 149 |
|
61 | | - async def forward_backward( |
62 | | - self, data: list[dict[str, torch.Tensor]] |
| 150 | + async def save_state( |
| 151 | + self, name: str | None = None, path: str | None = None |
63 | 152 | ) -> dict[str, Any]: |
64 | | - """Clear first, then forward+backward on all items in data. |
| 153 | + """Save a checkpoint of the current trainer state. |
65 | 154 |
|
66 | | - Convenience wrapper equivalent to: |
67 | | - clear_gradients() + accumulate_gradients() for each item |
| 155 | + Saves the complete training state including model weights, optimizer state, |
| 156 | + learning rate scheduler state, and current step counter. This checkpoint |
| 157 | + can be loaded later to resume training from this exact point. |
68 | 158 |
|
69 | | - Does NOT call optim_step() - you must call it separately. |
| 159 | + Args: |
| 160 | + name: Optional checkpoint name/identifier. If None, uses the current |
| 161 | + step number (e.g., "step-1000"). |
| 162 | + path: Optional base directory or URI where checkpoint should be saved. |
| 163 | + If None, uses the default checkpoint directory configured at trainer |
| 164 | + creation. Supports different backends via URI schemes: |
| 165 | + - `/local/path` - local filesystem |
| 166 | + - `ts://key` - TorchStore |
| 167 | + - `s3://bucket/key` - S3 |
| 168 | +
|
| 169 | + Location resolution: |
| 170 | + - Both provided: path/name (e.g., "/checkpoints" + "best" = "/checkpoints/best") |
| 171 | + - Only path: use path directly |
| 172 | + - Only name: default_dir/name |
| 173 | + - Neither: default_dir/step-{current_step} |
70 | 174 |
|
71 | 175 | Returns: |
72 | | - dict with keys: |
73 | | - - loss: float |
74 | | - - metrics: dict[str, float] |
| 176 | + dict containing: |
| 177 | + - path: str - Full path where checkpoint was saved |
| 178 | + - step: int - Training step at which checkpoint was saved |
| 179 | +
|
| 180 | + Example: |
| 181 | + >>> # Save to default location with step number |
| 182 | + >>> result = await trainer.save_state() # => /default/step-1000 |
| 183 | + >>> |
| 184 | + >>> # Save with custom name to default location |
| 185 | + >>> result = await trainer.save_state("best-model") # => /default/best-model |
| 186 | + >>> |
| 187 | + >>> # Save to custom base directory |
| 188 | + >>> result = await trainer.save_state("final", "/custom/checkpoints") |
| 189 | + >>> # => /custom/checkpoints/final |
75 | 190 | """ |
76 | 191 | ... |
77 | 192 |
|
78 | | - async def save_state(self, name: str) -> dict[str, Any]: |
79 | | - """Save the checkpoint. |
| 193 | + async def load_state(self, path: str | None = None) -> dict[str, Any]: |
| 194 | + """Load a previously saved checkpoint. |
| 195 | +
|
| 196 | + Restores the complete training state from a checkpoint, including model |
| 197 | + weights, optimizer state, learning rate scheduler state, and step counter. |
| 198 | +
|
| 199 | + Args: |
| 200 | + path: Optional path or URI to the checkpoint to load. If None, loads |
| 201 | + the most recent checkpoint from the default directory. Can be: |
| 202 | + - `/local/path/checkpoint` - local filesystem |
| 203 | + - `ts://key` - TorchStore |
| 204 | + - `s3://bucket/key` - S3 |
| 205 | +
|
| 206 | + Returns: |
| 207 | + dict containing: |
| 208 | + - step: int - Training step from the loaded checkpoint |
| 209 | + - learning_rate: float - Learning rate from the loaded checkpoint |
| 210 | +
|
| 211 | + Example: |
| 212 | + >>> # Load latest checkpoint from default location |
| 213 | + >>> result = await trainer.load_state() |
| 214 | + >>> print(f"Resumed from step {result['step']}") |
| 215 | + >>> |
| 216 | + >>> # Load specific checkpoint by path |
| 217 | + >>> result = await trainer.load_state("/checkpoints/step-5000") |
| 218 | + >>> |
| 219 | + >>> # Load from TorchStore |
| 220 | + >>> result = await trainer.load_state("ts://checkpoint-key") |
| 221 | + """ |
| 222 | + ... |
| 223 | + |
| 224 | + async def save_weights( |
| 225 | + self, name: str | None = None, path: str | None = None |
| 226 | + ) -> dict[str, Any]: |
| 227 | + """Save model weights only (without optimizer/scheduler state). |
| 228 | +
|
| 229 | + Saves only the model weights in a format suitable for inference/sampling. |
| 230 | + This is lighter weight than save_state() since it excludes training state |
| 231 | + like optimizer and scheduler. |
| 232 | +
|
| 233 | + Args: |
| 234 | + name: Optional checkpoint name/identifier. If None, uses the current |
| 235 | + step number (e.g., "weights-step-1000"). |
| 236 | + path: Optional base directory or URI where weights should be saved. |
| 237 | + If None, uses the default location configured at trainer creation. |
| 238 | + Supports different backends via URI schemes: |
| 239 | + - `/local/path` - local filesystem |
| 240 | + - `ts://key` - TorchStore |
| 241 | + - `s3://bucket/key` - S3 |
| 242 | +
|
| 243 | + Location resolution: |
| 244 | + - Both provided: path/name |
| 245 | + - Only path: use path directly |
| 246 | + - Only name: default_dir/name |
| 247 | + - Neither: default_dir/step-{current_step} |
80 | 248 |
|
81 | 249 | Returns: |
82 | | - dict with keys: |
83 | | - - path: str |
84 | | - - step: int |
| 250 | + dict containing: |
| 251 | + - path: str - Full URI where weights were saved |
| 252 | + - version: str | int - The name/version that was saved |
| 253 | +
|
| 254 | + Example: |
| 255 | + >>> # Save to default location with step number |
| 256 | + >>> result = await trainer.save_weights() |
| 257 | + >>> |
| 258 | + >>> # Save to TorchStore for inference server |
| 259 | + >>> result = await trainer.save_weights("policy-v1", "ts://policy-weights") |
| 260 | + >>> # → ts://policy-weights/policy-v1 |
| 261 | + >>> |
| 262 | + >>> # Save to S3 |
| 263 | + >>> result = await trainer.save_weights(path="s3://bucket/models/final") |
85 | 264 | """ |
86 | 265 | ... |
87 | 266 |
|
88 | | - async def load_state(self, path: str) -> dict[str, Any]: |
89 | | - """Load checkpoint. |
| 267 | + async def get_info(self) -> TrainerInfo: |
| 268 | + """Get static trainer and model metadata. |
| 269 | +
|
| 270 | + Returns information about the trainer configuration and model architecture |
| 271 | + that doesn't change during training. |
90 | 272 |
|
91 | 273 | Returns: |
92 | | - dict with keys: |
93 | | - - step: int |
94 | | - - learning_rate: float |
| 274 | + TrainerInfo containing model name, step, config, and parallelism settings |
| 275 | +
|
| 276 | + Example: |
| 277 | + >>> info = await trainer.get_info() |
| 278 | + >>> print(f"Training {info.model_name} at step {info.step}") |
| 279 | + >>> print(f"Vocab size: {info.config['vocab_size']}") |
| 280 | + >>> print(f"Data parallel degree: {info.parallelism['dp_degree']}") |
95 | 281 | """ |
96 | 282 | ... |
97 | 283 |
|
98 | | - async def save_weights_for_sampler(self, name: str) -> dict[str, Any]: |
99 | | - """Export weights for inference. |
| 284 | + async def get_status(self) -> TrainerStatus: |
| 285 | + """Get current runtime status of the trainer. |
| 286 | +
|
| 287 | + Returns dynamic information about the trainer's current state that changes |
| 288 | + during training. |
100 | 289 |
|
101 | 290 | Returns: |
102 | | - dict with keys: |
103 | | - - path: str |
104 | | - - version: str or int |
| 291 | + TrainerStatus containing current step and accumulated batch count |
| 292 | +
|
| 293 | + Example: |
| 294 | + >>> status = await trainer.get_status() |
| 295 | + >>> print(f"Current step: {status.step}") |
| 296 | + >>> if status.accumulated_microbatches > 0: |
| 297 | + >>> print(f"Warning: {status.accumulated_microbatches} " |
| 298 | + >>> f"batches accumulated without optimizer step") |
105 | 299 | """ |
106 | 300 | ... |
107 | 301 |
|
108 | 302 | def get_tokenizer(self): |
109 | | - """Get the tokenizer. |
| 303 | + """Get the tokenizer associated with this model. |
| 304 | +
|
| 305 | + Returns the tokenizer used for encoding/decoding text with this model. |
| 306 | + Useful for preprocessing inputs or decoding model outputs. |
110 | 307 |
|
111 | 308 | Returns: |
112 | | - PreTrainedTokenizer |
| 309 | + PreTrainedTokenizer: The HuggingFace tokenizer for this model |
| 310 | +
|
| 311 | + Note: |
| 312 | + This is a synchronous method (not async) since tokenizer access is |
| 313 | + typically fast and doesn't require remote calls. |
| 314 | +
|
| 315 | + Example: |
| 316 | + >>> tokenizer = trainer.get_tokenizer() |
| 317 | + >>> tokens = tokenizer.encode("Hello world") |
| 318 | + >>> text = tokenizer.decode([1, 2, 3, 4]) |
113 | 319 | """ |
114 | 320 | ... |
0 commit comments