# Callbacks

Megatron Bridge provides a lightweight callback system for injecting custom logic into the training and evaluation loops without modifying framework code. This is ideal for proprietary integrations or custom logging and metrics tracking.

## Quick Start

### Class-Based Callbacks

Subclass {py:class}`bridge.training.callbacks.Callback` and override event methods:

```python
import time

from megatron.bridge.training.callbacks import Callback
from megatron.bridge.training.pretrain import pretrain

class MyCallback(Callback):
    def on_train_start(self, context):
        context.user_state['start_time'] = time.time()
        print(f"Training started at step {context.state.train_state.step}")

    def on_train_step_end(self, context):
        if context.loss_dict:
            print(f"Step {context.state.train_state.step}: loss={context.loss_dict}")

    def on_train_end(self, context):
        elapsed = time.time() - context.user_state['start_time']
        print(f"Training completed in {elapsed:.2f}s")

# Pass callbacks to pretrain
pretrain(config, forward_step_func, callbacks=[MyCallback()])
```

### Functional Callbacks

Register functions directly with {py:class}`bridge.training.callbacks.CallbackManager`:

```python
from megatron.bridge.training.callbacks import CallbackManager
from megatron.bridge.training import pretrain

def log_step(context):
    step = context.state.train_state.step
    if context.loss_dict:
        print(f"Step {step}: {context.loss_dict}")

callback_manager = CallbackManager()
callback_manager.register("on_train_step_end", log_step)

pretrain(config, forward_step_func, callbacks=callback_manager)
```

### Mixing Both Patterns

Both registration patterns can be combined:

```python
manager = CallbackManager()
manager.add(MyCallback())
manager.add([TimingCallback(), MetricsCallback()])
manager.register("on_eval_end", lambda ctx: print("Evaluation complete!"))

pretrain(config, forward_step_func, callbacks=manager)
```

## Available Events

### Training Events

| Event | When Fired | Available Context Fields |
|-------|------------|-------------------------|
| `on_train_start` | After `model.train()`, before training loop | `state`, `model`, `user_state`, `optimizer`, `scheduler` |
| `on_train_step_start` | Before each training step | `state`, `model`, `user_state`, `optimizer`, `scheduler` |
| `on_train_step_end` | After each training step | `state`, `model`, `user_state`, `optimizer`, `scheduler`, `loss_dict`, `grad_norm`, `skipped_iter` |
| `on_train_end` | After training loop completes | `state`, `model`, `user_state`, `optimizer`, `scheduler` |

### Validation Events

| Event | When Fired | Available Context Fields |
|-------|------------|-------------------------|
| `on_eval_start` | After `model.eval()`, before validation loop | `state`, `model`, `user_state` |
| `on_eval_step_start` | Before each validation step | `state`, `model`, `user_state` |
| `on_eval_step_end` | After each validation step | `state`, `model`, `user_state` |
| `on_eval_end` | After validation completes | `state`, `model`, `user_state`, `total_loss_dict` |

### Test Events

| Event | When Fired | Available Context Fields |
|-------|------------|-------------------------|
| `on_test_start` | After `model.eval()`, before test loop | `state`, `model`, `user_state` |
| `on_test_step_start` | Before each test step | `state`, `model`, `user_state` |
| `on_test_step_end` | After each test step | `state`, `model`, `user_state` |
| `on_test_end` | After test completes | `state`, `model`, `user_state`, `total_loss_dict` |
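
For example, a callback can consume the aggregated losses once `on_eval_end` fires. A minimal sketch (the keys in `total_loss_dict` depend on what your forward step reports, so the printed names are whatever your run produces):

```python
class EvalLossLogger(Callback):
    """Print losses aggregated over the validation loop."""

    def on_eval_end(self, context):
        step = context.state.train_state.step
        # total_loss_dict maps loss names to values reduced across the eval run
        for name, value in context.total_loss_dict.items():
            print(f"[eval @ step {step}] {name}: {value}")
```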

## CallbackContext

The {py:class}`bridge.training.callbacks.CallbackContext` provides access to framework state:

### Always Available

- **`state`**: {py:class}`bridge.training.state.GlobalState` - Contains config, train_state, timers, and loggers
- **`model`**: List of model chunks
- **`user_state`**: Mutable dict for storing data across callback invocations

### Training Events Only

- **`optimizer`**: The optimizer instance
- **`scheduler`**: Learning rate scheduler (see the sketch below)
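
For example, a callback can log the current learning rate on training events. A sketch, assuming the optimizer exposes PyTorch-style `param_groups` (adapt the access pattern to your optimizer wrapper):

```python
class LRLogger(Callback):
    def on_train_step_start(self, context):
        # optimizer is only populated for training events
        if context.optimizer is not None:
            # Assumes a PyTorch-style param_groups attribute on the optimizer
            lr = context.optimizer.param_groups[0]["lr"]
            print(f"Step {context.state.train_state.step}: lr={lr:.3e}")
```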

### Event-Specific Fields

- **`loss_dict`** (`on_train_step_end`): Dictionary of reduced losses from the training step
- **`grad_norm`** (`on_train_step_end`): Gradient norm, if computed (guarded access is shown below)
- **`skipped_iter`** (`on_train_step_end`): Whether the iteration was skipped
- **`total_loss_dict`** (`on_eval_end`, `on_test_end`): Aggregated evaluation/test losses
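
Because these fields are only populated for their event and may be unset (for example, when an iteration is skipped), guard each access. A minimal sketch:

```python
class StepDiagnostics(Callback):
    def on_train_step_end(self, context):
        step = context.state.train_state.step
        if context.skipped_iter:
            # Losses and grad norm may be absent for a skipped iteration
            print(f"Step {step} was skipped")
            return
        if context.grad_norm is not None:
            print(f"Step {step}: grad_norm={context.grad_norm:.4f}")
```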

## User State

The `CallbackManager` owns a `user_state` dictionary that persists across all callback invocations during a training run. Use it to share data between callbacks or accumulate metrics:

```python
class StepCounterCallback(Callback):
    def on_train_start(self, context):
        context.user_state['callback_step_count'] = 0

    def on_train_step_end(self, context):
        context.user_state['callback_step_count'] += 1

    def on_train_end(self, context):
        print(f"Callback saw {context.user_state['callback_step_count']} steps")
```

## Distributed Training

Callbacks fire on **all ranks** without framework-level synchronization. If your callback should only run on specific ranks, add guards:

```python
import torch.distributed as dist

class RankZeroCallback(Callback):
    def on_train_step_end(self, context):
        if dist.get_rank() == 0:
            print(f"Step {context.state.train_state.step} complete")
```

## Exception Handling

Exceptions from callbacks propagate to the caller. The framework does not catch or handle callback exceptions. If your callback might fail, wrap it in a try-except:

```python
def safe_callback(context):
    try:
        # Your logic here
        external_service.log(context.loss_dict)
    except Exception as e:
        print(f"Callback failed: {e}")
        # Don't re-raise, to avoid stopping training
```

## Execution Order

Callbacks fire in registration order:

1. Callbacks added via `add()` fire in the order they were added
2. Callbacks registered via `register()` fire in the order they were registered
3. If both methods are used, the order depends on when each was called (see the sketch below)
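
For example, with the interleaved registrations below (`CallbackA`, `CallbackB`, and `log_step` are hypothetical), handlers for `on_train_step_end` would fire in the order `CallbackA`, `log_step`, `CallbackB`:

```python
manager = CallbackManager()
manager.add(CallbackA())                          # registered first
manager.register("on_train_step_end", log_step)   # registered second
manager.add(CallbackB())                          # registered third
```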

## Introspection

Query registered callbacks:

```python
manager = CallbackManager()
manager.register("on_train_start", my_fn)

# Check if any callbacks exist for an event
if manager.has_callbacks("on_train_start"):
    print("Callbacks registered for on_train_start")

# List all callbacks for an event
callbacks = manager.list_callbacks("on_train_start")
print(f"Found {len(callbacks)} callbacks")

# Get all valid event names
print(manager.events)  # frozenset of valid event names
```

## Design Principles

The callback system follows these principles:

1. **First-Party Isolation**: Framework code never uses callbacks for its own logic. Callbacks are strictly for third-party extensions.

2. **Zero Overhead**: When no callbacks are registered, there is zero performance overhead.

3. **Safety**: Callbacks receive framework state, but modifying it is at the user's own risk. The framework makes no guarantees about the effects of modifications.

## Examples

### Proprietary Metrics

```python
import os

class ProprietaryMetricsCallback(Callback):
    """Send metrics to an internal monitoring system."""

    def __init__(self, endpoint: str):
        # InternalMetricsClient is a stand-in for your own client
        self.client = InternalMetricsClient(endpoint)

    def on_train_step_end(self, context):
        if context.loss_dict:
            self.client.send({
                "step": context.state.train_state.step,
                "loss": context.loss_dict.get("lm loss"),
                "grad_norm": context.grad_norm,
                "cluster_id": os.environ.get("CLUSTER_ID"),
            })
```

## API Reference

- {py:class}`bridge.training.callbacks.Callback`
- {py:class}`bridge.training.callbacks.CallbackContext`
- {py:class}`bridge.training.callbacks.CallbackManager`
- {py:func}`bridge.training.callbacks.normalize_callbacks`
- {py:func}`bridge.training.callbacks.should_fire`