
Commit d89bc1b

Author: Allen Wang (committed)
Commit message: bulk changes

1 parent 189b242

File tree: 3 files changed, +452 -51 lines changed


src/forge/api/__init__.py

Lines changed: 24 additions & 1 deletion
@@ -4,4 +4,27 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# Forge library modules
+"""Forge public API module.
+
+This module defines the public interfaces that all Forge implementations conform to.
+"""
+
+from forge.api.trainer import Trainer
+from forge.api.types import (
+    ForwardResult,
+    OptimStepResult,
+    TextTrainBatch,
+    TrainerInfo,
+    TrainerStatus,
+    TrainResult,
+)
+
+__all__ = [
+    "Trainer",
+    "TextTrainBatch",
+    "TrainResult",
+    "OptimStepResult",
+    "ForwardResult",
+    "TrainerInfo",
+    "TrainerStatus",
+]
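For reference, a minimal sketch of what the re-exported surface lets downstream code do. The `forge.api` names come straight from the diff above; the helper functions and toy tensors are purely illustrative:

# Sketch only: the imports mirror the new __all__ above; everything else is hypothetical.
import torch

from forge.api import TextTrainBatch, Trainer


def build_batch(token_ids: list[int]) -> TextTrainBatch:
    # Shift tokens by one position to form next-token (input, target) pairs,
    # matching the input_ids/target_ids convention in the Trainer docstrings.
    ids = torch.tensor([token_ids])
    return TextTrainBatch(input_ids=ids[:, :-1], target_ids=ids[:, 1:])


def implements_trainer(obj: object) -> bool:
    # Trainer is a @runtime_checkable Protocol, so isinstance() checks whether
    # obj structurally provides the trainer methods (signatures are not verified).
    return isinstance(obj, Trainer)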

src/forge/api/trainer.py

Lines changed: 256 additions & 50 deletions
@@ -4,111 +4,317 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""Trainer protocol.
+"""Trainer protocol for Forge.
+
+This module defines the unified training interface that all trainer implementations
+must conform to.
 
-This file defines the unified training interface compatible
-with all supported torchforge trainers.
 """
 
 from typing import Any, Protocol, runtime_checkable
 
 import torch
 
+from forge.api.types import (
+    ForwardResult,
+    OptimStepResult,
+    TextTrainBatch,
+    TrainerInfo,
+    TrainerStatus,
+    TrainResult,
+)
+
 
 @runtime_checkable
 class Trainer(Protocol):
-    """Protocol for all trainers in torchforge."""
+    """Protocol defining the standard interface for all Forge trainers."""
 
-    async def accumulate_gradients(
-        self, microbatch: dict[str, torch.Tensor]
-    ) -> dict[str, Any]:
-        """Accumulate gradients from one microbatch.
+    async def forward_backward(self, batch: TextTrainBatch) -> TrainResult:
+        """Execute forward pass and backward pass for one batch of data.
 
-        Does NOT clear gradients - they accumulate on top of existing.
-        Can be called multiple times before optim_step().
+        Basic usage - single batch per optimizer step:
+        >>> batch = TextTrainBatch(
+        >>>     input_ids=torch.tensor([[1, 2, 3, 4, 5]]),
+        >>>     target_ids=torch.tensor([[2, 3, 4, 5, 6]]),
+        >>> )
+        >>> result = await trainer.forward_backward(batch)
+        >>> await trainer.optim_step()  # Apply gradients
+
+        To accumulate gradients over multiple batches before optimizer step:
+        >>> await trainer.forward_backward(batch1)  # Accumulates
+        >>> await trainer.forward_backward(batch2)  # Accumulates another batch
+        >>> await trainer.optim_step()  # Apply all accumulated gradients
+
+        Args:
+            batch: TextTrainBatch containing input_ids, target_ids, and optional
+                target_mask/target_weights. See forge.api.types.TextTrainBatch for details.
 
         Returns:
-            dict with keys:
-            - loss: float
-            - metrics: dict[str, float]
+            TrainResult containing loss and metrics
+
+        Note:
+            The loss function is configured at trainer creation time via the
+            `loss` parameter, not passed to this method.
         """
         ...
 
-    async def optim_step(self, params: dict[str, Any] | None = None) -> dict[str, Any]:
-        """Apply optimizer step and clear gradients after.
+    async def optim_step(self, params: dict[str, Any] | None = None) -> OptimStepResult:
+        """Apply optimizer step using accumulated gradients, then clear gradients.
+
+        This method:
+        1. Applies accumulated gradients via the optimizer
+        2. Steps the learning rate scheduler
+        3. Clears all gradients (zero_grad)
+        4. Increments the training step counter
+        5. May trigger automatic checkpointing (implementation-dependent)
+
+        Gradients must have been accumulated via forward_backward() calls before
+        calling this method.
+
+        Args:
+            params: Optional optimizer parameters. Currently reserved for future use.
+                Most implementations ignore this and use the optimizer config from
+                trainer initialization.
 
         Returns:
-            dict with keys:
-            - step: int
-            - learning_rate: float
-            - accumulated_microbatches: int
+            OptimStepResult containing step number, learning rate, and accumulated batch count
+
+        Example:
+            >>> # Accumulate over 4 batches
+            >>> for batch in batches[:4]:
+            >>>     await trainer.forward_backward(batch)
+            >>> result = await trainer.optim_step()
+            >>> print(f"Step {result.step}, LR {result.learning_rate:.2e}")
+            >>> print(f"Accumulated {result.accumulated_microbatches} batches")
         """
         ...
 
     async def clear_gradients(self) -> None:
-        """Clear accumulated gradients without applying."""
+        """Clear accumulated gradients without applying them.
+
+        Use this when you need to discard accumulated gradients without performing
+        an optimizer step. Common scenarios:
+        - Exception during gradient accumulation
+        - Skipping a training step due to some condition
+        - Recovering from OOM or other errors
+
+        This is equivalent to calling optimizer.zero_grad() and resetting internal
+        accumulation counters.
+
+        Example - Error recovery:
+        >>> try:
+        >>>     for batch in batches:
+        >>>         await trainer.forward_backward(batch)
+        >>>     await trainer.optim_step()
+        >>> except torch.cuda.OutOfMemoryError:
+        >>>     await trainer.clear_gradients()  # Discard partial gradients
+        >>>     # Retry with smaller batches
+
+        Example - Conditional skip:
+        >>> await trainer.forward_backward(batch)
+        >>> if should_skip_step():
+        >>>     await trainer.clear_gradients()  # Don't apply these gradients
+        >>> else:
+        >>>     await trainer.optim_step()
+        """
         ...
 
-    async def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        """Run forward pass, no backward.
+    async def forward(self, inputs: dict[str, torch.Tensor]) -> ForwardResult:
+        """Run forward pass only, without backward pass (for evaluation/inference).
+
+        This method executes the model's forward pass without computing gradients.
+        Useful for:
+        - Evaluation on validation/test data
+        - Getting model predictions/logits
+        - Debugging model outputs
+
+        Args:
+            inputs: Dictionary containing model inputs. Typically includes:
+                - input_ids: torch.Tensor [batch_size, seq_len]
+                Other keys depend on the model architecture.
 
         Returns:
-            dict with key:
-            - logits: torch.Tensor
+            ForwardResult containing model logits
+
+        Note:
+            This runs in torch.no_grad() context - no gradients are computed.
+
+        Example:
+            >>> eval_batch = {"input_ids": torch.tensor([[1, 2, 3, 4]])}
+            >>> output = await trainer.forward(eval_batch)
+            >>> logits = output.logits  # [1, 4, vocab_size]
+            >>> predictions = logits.argmax(dim=-1)  # [1, 4]
         """
         ...
 
-    async def forward_backward(
-        self, data: list[dict[str, torch.Tensor]]
+    async def save_state(
+        self, name: str | None = None, path: str | None = None
     ) -> dict[str, Any]:
-        """Clear first, then forward+backward on all items in data.
+        """Save a checkpoint of the current trainer state.
 
-        Convenience wrapper equivalent to:
-        clear_gradients() + accumulate_gradients() for each item
+        Saves the complete training state including model weights, optimizer state,
+        learning rate scheduler state, and current step counter. This checkpoint
+        can be loaded later to resume training from this exact point.
 
-        Does NOT call optim_step() - you must call it separately.
+        Args:
+            name: Optional checkpoint name/identifier. If None, uses the current
+                step number (e.g., "step-1000").
+            path: Optional base directory or URI where checkpoint should be saved.
+                If None, uses the default checkpoint directory configured at trainer
+                creation. Supports different backends via URI schemes:
+                - `/local/path` - local filesystem
+                - `ts://key` - TorchStore
+                - `s3://bucket/key` - S3
+
+        Location resolution:
+        - Both provided: path/name (e.g., "/checkpoints" + "best" = "/checkpoints/best")
+        - Only path: use path directly
+        - Only name: default_dir/name
+        - Neither: default_dir/step-{current_step}
 
         Returns:
-            dict with keys:
-            - loss: float
-            - metrics: dict[str, float]
+            dict containing:
+            - path: str - Full path where checkpoint was saved
+            - step: int - Training step at which checkpoint was saved
+
+        Example:
+            >>> # Save to default location with step number
+            >>> result = await trainer.save_state()  # => /default/step-1000
+            >>>
+            >>> # Save with custom name to default location
+            >>> result = await trainer.save_state("best-model")  # => /default/best-model
+            >>>
+            >>> # Save to custom base directory
+            >>> result = await trainer.save_state("final", "/custom/checkpoints")
+            >>> # => /custom/checkpoints/final
         """
         ...
 
-    async def save_state(self, name: str) -> dict[str, Any]:
-        """Save the checkpoint.
+    async def load_state(self, path: str | None = None) -> dict[str, Any]:
+        """Load a previously saved checkpoint.
+
+        Restores the complete training state from a checkpoint, including model
+        weights, optimizer state, learning rate scheduler state, and step counter.
+
+        Args:
+            path: Optional path or URI to the checkpoint to load. If None, loads
+                the most recent checkpoint from the default directory. Can be:
+                - `/local/path/checkpoint` - local filesystem
+                - `ts://key` - TorchStore
+                - `s3://bucket/key` - S3
+
+        Returns:
+            dict containing:
+            - step: int - Training step from the loaded checkpoint
+            - learning_rate: float - Learning rate from the loaded checkpoint
+
+        Example:
+            >>> # Load latest checkpoint from default location
+            >>> result = await trainer.load_state()
+            >>> print(f"Resumed from step {result['step']}")
+            >>>
+            >>> # Load specific checkpoint by path
+            >>> result = await trainer.load_state("/checkpoints/step-5000")
+            >>>
+            >>> # Load from TorchStore
+            >>> result = await trainer.load_state("ts://checkpoint-key")
+        """
+        ...
+
+    async def save_weights(
+        self, name: str | None = None, path: str | None = None
+    ) -> dict[str, Any]:
+        """Save model weights only (without optimizer/scheduler state).
+
+        Saves only the model weights in a format suitable for inference/sampling.
+        This is lighter weight than save_state() since it excludes training state
+        like optimizer and scheduler.
+
+        Args:
+            name: Optional checkpoint name/identifier. If None, uses the current
+                step number (e.g., "weights-step-1000").
+            path: Optional base directory or URI where weights should be saved.
+                If None, uses the default location configured at trainer creation.
+                Supports different backends via URI schemes:
+                - `/local/path` - local filesystem
+                - `ts://key` - TorchStore
+                - `s3://bucket/key` - S3
+
+        Location resolution:
+        - Both provided: path/name
+        - Only path: use path directly
+        - Only name: default_dir/name
+        - Neither: default_dir/step-{current_step}
 
         Returns:
-            dict with keys:
-            - path: str
-            - step: int
+            dict containing:
+            - path: str - Full URI where weights were saved
+            - version: str | int - The name/version that was saved
+
+        Example:
+            >>> # Save to default location with step number
+            >>> result = await trainer.save_weights()
+            >>>
+            >>> # Save to TorchStore for inference server
+            >>> result = await trainer.save_weights("policy-v1", "ts://policy-weights")
+            >>> # → ts://policy-weights/policy-v1
+            >>>
+            >>> # Save to S3
+            >>> result = await trainer.save_weights(path="s3://bucket/models/final")
         """
         ...
 
-    async def load_state(self, path: str) -> dict[str, Any]:
-        """Load checkpoint.
+    async def get_info(self) -> TrainerInfo:
+        """Get static trainer and model metadata.
+
+        Returns information about the trainer configuration and model architecture
+        that doesn't change during training.
 
         Returns:
-            dict with keys:
-            - step: int
-            - learning_rate: float
+            TrainerInfo containing model name, step, config, and parallelism settings
+
+        Example:
+            >>> info = await trainer.get_info()
+            >>> print(f"Training {info.model_name} at step {info.step}")
+            >>> print(f"Vocab size: {info.config['vocab_size']}")
+            >>> print(f"Data parallel degree: {info.parallelism['dp_degree']}")
         """
         ...
 
-    async def save_weights_for_sampler(self, name: str) -> dict[str, Any]:
-        """Export weights for inference.
+    async def get_status(self) -> TrainerStatus:
+        """Get current runtime status of the trainer.
+
+        Returns dynamic information about the trainer's current state that changes
+        during training.
 
         Returns:
-            dict with keys:
-            - path: str
-            - version: str or int
+            TrainerStatus containing current step and accumulated batch count
+
+        Example:
+            >>> status = await trainer.get_status()
+            >>> print(f"Current step: {status.step}")
+            >>> if status.accumulated_microbatches > 0:
+            >>>     print(f"Warning: {status.accumulated_microbatches} "
+            >>>           f"batches accumulated without optimizer step")
         """
         ...
 
     def get_tokenizer(self):
-        """Get the tokenizer.
+        """Get the tokenizer associated with this model.
+
+        Returns the tokenizer used for encoding/decoding text with this model.
+        Useful for preprocessing inputs or decoding model outputs.
 
         Returns:
-            PreTrainedTokenizer
+            PreTrainedTokenizer: The HuggingFace tokenizer for this model
+
+        Note:
+            This is a synchronous method (not async) since tokenizer access is
+            typically fast and doesn't require remote calls.
+
+        Example:
+            >>> tokenizer = trainer.get_tokenizer()
+            >>> tokens = tokenizer.encode("Hello world")
+            >>> text = tokenizer.decode([1, 2, 3, 4])
         """
         ...
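Taken together, the protocol supports the usual accumulate-then-step loop. The sketch below drives only the methods defined above; the trainer instance, batch source, checkpoint cadence, and the "ts://policy-weights" destination are hypothetical placeholders, and field access (result.loss, opt.step, opt.learning_rate) follows the result-type descriptions in the docstrings:

# Sketch of a gradient-accumulation loop against the Trainer protocol above.
# Only the Trainer methods come from this commit; everything else is assumed.
from collections.abc import Iterable

import torch

from forge.api import TextTrainBatch, Trainer


async def train(
    trainer: Trainer, batches: Iterable[TextTrainBatch], accum_steps: int = 4
) -> None:
    pending = 0
    for batch in batches:
        try:
            result = await trainer.forward_backward(batch)  # accumulates gradients
            pending += 1
        except torch.cuda.OutOfMemoryError:
            # Discard partial gradients and continue, per the clear_gradients() docs.
            await trainer.clear_gradients()
            pending = 0
            continue

        if pending == accum_steps:
            opt = await trainer.optim_step()  # apply gradients, step LR, zero_grad
            pending = 0
            print(f"step={opt.step} lr={opt.learning_rate:.2e} loss={result.loss:.4f}")

            if opt.step % 100 == 0:
                # Export inference-only weights (no optimizer state) for a sampler.
                await trainer.save_weights(f"policy-v{opt.step}", "ts://policy-weights")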
