Commit 1d82b25

feat: Support callbacks for third-party extensions (#2063)

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>

1 parent 38858ef commit 1d82b25

File tree

10 files changed: +1666 −9 lines changed
.github/workflows/cicd-main.yml

Lines changed: 1 addition & 1 deletion

@@ -378,7 +378,7 @@ jobs:
       matrix:
         include:
           - script: L2_Launch_training
-            timeout: 40
+            timeout: 50
           - script: L2_Launch_converter
           - script: L2_Launch_models_deepseek
           - script: L2_Launch_models_gemma

docs/index.md

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ training/cpu-offloading.md
 training/peft.md
 training/packed-sequences.md
 training/distillation.md
+training/callbacks.md
 ```

 ```{toctree}

docs/training/README.md

Lines changed: 3 additions & 1 deletion

@@ -22,7 +22,7 @@ This directory contains comprehensive documentation for training and customizing
 → Explore [Performance Guide](../performance-guide.md) and [Performance Summary](../performance-summary.md)

 **🔧 Customize training**
-→ See [PEFT](peft.md), [Distillation](distillation.md), and [Entry Points](entry-points.md)
+→ See [PEFT](peft.md), [Distillation](distillation.md), [Entry Points](entry-points.md), and [Callbacks](callbacks.md)

 ## Core Training Documentation

@@ -61,6 +61,7 @@ This directory contains comprehensive documentation for training and customizing
 | **[Packed Sequences](packed-sequences.md)** | Sequence packing for efficiency | Optimizing data loading |
 | **[Distillation](distillation.md)** | Knowledge distillation techniques | Transferring knowledge between models |
 | **[Checkpointing](checkpointing.md)** | Checkpoint saving, loading, and resuming | Managing training state |
+| **[Callbacks](callbacks.md)** | Inject custom logic into training loop | Custom logging, metrics, third-party integrations |

 ## Training Workflow

@@ -113,6 +114,7 @@ A typical training workflow involves:
 1. [PEFT](peft.md) - Parameter-efficient fine-tuning
 2. [Distillation](distillation.md) - Knowledge distillation
 3. [Entry Points](entry-points.md) - Custom training workflows
+4. [Callbacks](callbacks.md) - Inject custom logic (third-party integrations)

 ---

docs/training/callbacks.md

Lines changed: 223 additions & 0 deletions

@@ -0,0 +1,223 @@

# Callbacks

Megatron Bridge provides a lightweight callback system for injecting custom logic into the training and evaluation loop without modifying framework code. This is ideal for proprietary integrations or custom logging and metrics tracking.

## Quick Start

### Class-Based Callbacks

Subclass {py:class}`bridge.training.callbacks.Callback` and override event methods:

```python
import time

from megatron.bridge.training.callbacks import Callback
from megatron.bridge.training.pretrain import pretrain

class MyCallback(Callback):
    def on_train_start(self, context):
        context.user_state['start_time'] = time.time()
        print(f"Training started at step {context.state.train_state.step}")

    def on_train_step_end(self, context):
        if context.loss_dict:
            print(f"Step {context.state.train_state.step}: loss={context.loss_dict}")

    def on_train_end(self, context):
        elapsed = time.time() - context.user_state['start_time']
        print(f"Training completed in {elapsed:.2f}s")

# Pass callbacks to pretrain
pretrain(config, forward_step_func, callbacks=[MyCallback()])
```

### Functional Callbacks

Register functions directly with {py:class}`bridge.training.callbacks.CallbackManager`:

```python
from megatron.bridge.training.callbacks import CallbackManager
from megatron.bridge.training import pretrain

def log_step(context):
    step = context.state.train_state.step
    if context.loss_dict:
        print(f"Step {step}: {context.loss_dict}")

callback_manager = CallbackManager()
callback_manager.register("on_train_step_end", log_step)

pretrain(config, forward_step_func, callbacks=callback_manager)
```

### Mixing Both Patterns

Both registration patterns can be combined:

```python
manager = CallbackManager()
manager.add(MyCallback())
manager.add([TimingCallback(), MetricsCallback()])
manager.register("on_eval_end", lambda ctx: print("Evaluation complete!"))

pretrain(config, forward_step_func, callbacks=manager)
```

## Available Events

### Training Events

| Event | When Fired | Available Context Fields |
|-------|------------|--------------------------|
| `on_train_start` | After `model.train()`, before training loop | `state`, `model`, `user_state`, `optimizer`, `scheduler` |
| `on_train_step_start` | Before each training step | `state`, `model`, `user_state`, `optimizer`, `scheduler` |
| `on_train_step_end` | After each training step | `state`, `model`, `user_state`, `optimizer`, `scheduler`, `loss_dict`, `grad_norm`, `skipped_iter` |
| `on_train_end` | After training loop completes | `state`, `model`, `user_state`, `optimizer`, `scheduler` |

### Validation Events

| Event | When Fired | Available Context Fields |
|-------|------------|--------------------------|
| `on_eval_start` | After `model.eval()`, before validation loop | `state`, `model`, `user_state` |
| `on_eval_step_start` | Before each validation step | `state`, `model`, `user_state` |
| `on_eval_step_end` | After each validation step | `state`, `model`, `user_state` |
| `on_eval_end` | After validation completes | `state`, `model`, `user_state`, `total_loss_dict` |

### Test Events

| Event | When Fired | Available Context Fields |
|-------|------------|--------------------------|
| `on_test_start` | After `model.eval()`, before test loop | `state`, `model`, `user_state` |
| `on_test_step_start` | Before each test step | `state`, `model`, `user_state` |
| `on_test_step_end` | After each test step | `state`, `model`, `user_state` |
| `on_test_end` | After test completes | `state`, `model`, `user_state`, `total_loss_dict` |
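
For instance, the end-of-evaluation event exposes the aggregated losses once validation finishes. A minimal sketch (the `"lm loss"` key is illustrative and depends on how your forward step names its losses):

```python
from megatron.bridge.training.callbacks import Callback

class EvalReportCallback(Callback):
    """Print aggregated validation losses after each evaluation pass."""

    def on_eval_end(self, context):
        # total_loss_dict holds losses aggregated over the whole validation loop
        if context.total_loss_dict:
            step = context.state.train_state.step
            lm_loss = context.total_loss_dict.get("lm loss")  # illustrative key
            print(f"[eval @ step {step}] lm loss: {lm_loss}")
```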

## CallbackContext

The {py:class}`bridge.training.callbacks.CallbackContext` provides access to framework state:

### Always Available

- **`state`**: {py:class}`bridge.training.state.GlobalState` - Contains config, train_state, timers, and loggers
- **`model`**: List of model chunks
- **`user_state`**: Mutable dict for storing data across callback invocations

### Training Events Only

- **`optimizer`**: The optimizer instance
- **`scheduler`**: Learning rate scheduler

### Event-Specific Fields

- **`loss_dict`** (`on_train_step_end`): Dictionary of reduced losses from the training step
- **`grad_norm`** (`on_train_step_end`): Gradient norm (if computed)
- **`skipped_iter`** (`on_train_step_end`): Whether the iteration was skipped
- **`total_loss_dict`** (`on_eval_end`, `on_test_end`): Aggregated evaluation/test losses
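
Since `grad_norm` is only populated on steps where the framework computes it, guard access before use. A small sketch combining the step-end fields (assuming `grad_norm` is absent or `None` on steps where it is not computed):

```python
from megatron.bridge.training.callbacks import Callback

class StepHealthCallback(Callback):
    """Watch gradient norms and count skipped iterations."""

    def on_train_start(self, context):
        context.user_state['skipped_total'] = 0

    def on_train_step_end(self, context):
        # skipped_iter flags steps the framework skipped (e.g. on overflow)
        if context.skipped_iter:
            context.user_state['skipped_total'] += 1
        # grad_norm may not be computed on every step
        if context.grad_norm is not None:
            print(f"step {context.state.train_state.step}: grad_norm={context.grad_norm}")
```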

## User State

The `CallbackManager` owns a `user_state` dictionary that persists across all callback invocations during a training run. Use it to share data between callbacks or accumulate metrics:

```python
class StepCounterCallback(Callback):
    def on_train_start(self, context):
        context.user_state['callback_step_count'] = 0

    def on_train_step_end(self, context):
        context.user_state['callback_step_count'] += 1

    def on_train_end(self, context):
        print(f"Callback saw {context.user_state['callback_step_count']} steps")
```

## Distributed Training

Callbacks fire on **all ranks** without framework-level synchronization. If your callback should only run on specific ranks, add guards:

```python
import torch.distributed as dist

class RankZeroCallback(Callback):
    def on_train_step_end(self, context):
        if dist.get_rank() == 0:
            print(f"Step {context.state.train_state.step} complete")
```
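
If the same callback may also run in a single-process job where no process group was initialized, `dist.get_rank()` will raise. A safer guard, sketched here with the standard `torch.distributed` availability checks:

```python
import torch.distributed as dist

def is_rank_zero() -> bool:
    """True on rank 0, and in single-process runs without a process group."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True
```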

## Exception Handling

Exceptions from callbacks propagate to the caller. The framework does not catch or handle callback exceptions. If your callback might fail, wrap it in a try-except:

```python
def safe_callback(context):
    try:
        # Your logic here
        external_service.log(context.loss_dict)
    except Exception as e:
        print(f"Callback failed: {e}")
        # Don't re-raise, to avoid stopping training
```

## Execution Order

Callbacks fire in registration order:

1. Callbacks added via `add()` fire in the order they were added
2. Callbacks registered via `register()` fire in the order they were registered
3. If both methods are used, callbacks fire in the overall order in which the `add()` and `register()` calls were made, as sketched below
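
A small sketch of that interleaving (the callback names are illustrative):

```python
from megatron.bridge.training.callbacks import Callback, CallbackManager

class First(Callback):
    def on_train_start(self, context):
        print("first")

def second(context):
    print("second")

class Third(Callback):
    def on_train_start(self, context):
        print("third")

manager = CallbackManager()
manager.add(First())                         # registered first
manager.register("on_train_start", second)   # registered second
manager.add(Third())                         # registered third
# When on_train_start fires, the output order is: first, second, third
```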

## Introspection

Query registered callbacks:

```python
manager = CallbackManager()
manager.register("on_train_start", my_fn)

# Check if any callbacks exist for an event
if manager.has_callbacks("on_train_start"):
    print("Callbacks registered for on_train_start")

# List all callbacks for an event
callbacks = manager.list_callbacks("on_train_start")
print(f"Found {len(callbacks)} callbacks")

# Get all valid event names
print(manager.events)  # frozenset of valid event names
```

## Design Principles

The callback system follows these principles:

1. **First-Party Isolation**: Framework code never uses callbacks for its own logic. Callbacks are strictly for third-party extensions.

2. **Zero Overhead**: When no callbacks are registered, the training loop incurs no additional cost.

3. **Safety**: Callbacks receive framework state, but modifying it is at the user's own risk. The framework makes no guarantees about the effects of modifications.

## Examples

### Proprietary Metrics

```python
import os

class ProprietaryMetricsCallback(Callback):
    """Send metrics to an internal monitoring system."""

    def __init__(self, endpoint: str):
        # InternalMetricsClient is a stand-in for your own client
        self.client = InternalMetricsClient(endpoint)

    def on_train_step_end(self, context):
        if context.loss_dict:
            self.client.send({
                "step": context.state.train_state.step,
                "loss": context.loss_dict.get("lm loss"),
                "grad_norm": context.grad_norm,
                "cluster_id": os.environ.get("CLUSTER_ID"),
            })
```
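
### Step Timing

Another common pattern pairs the step start/end events to time each iteration; a minimal sketch:

```python
import time

class StepTimerCallback(Callback):
    """Measure wall-clock time per training step."""

    def on_train_step_start(self, context):
        context.user_state['step_start'] = time.perf_counter()

    def on_train_step_end(self, context):
        elapsed = time.perf_counter() - context.user_state['step_start']
        print(f"step {context.state.train_state.step}: {elapsed * 1000:.1f} ms")
```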

## API Reference

- {py:class}`bridge.training.callbacks.Callback`
- {py:class}`bridge.training.callbacks.CallbackContext`
- {py:class}`bridge.training.callbacks.CallbackManager`
- {py:func}`bridge.training.callbacks.normalize_callbacks`
- {py:func}`bridge.training.callbacks.should_fire`
