# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 15 | +"""Functional tests for the training callback system.""" |
| 16 | + |
| 17 | +import pytest |
| 18 | +import torch |
| 19 | + |
| 20 | +from megatron.bridge.models.llama import Llama32ModelProvider1B |
| 21 | +from megatron.bridge.training.callbacks import Callback, CallbackContext, CallbackManager |
| 22 | +from megatron.bridge.training.config import ( |
| 23 | + CheckpointConfig, |
| 24 | + ConfigContainer, |
| 25 | + DistributedDataParallelConfig, |
| 26 | + LoggerConfig, |
| 27 | + MockGPTDatasetConfig, |
| 28 | + OptimizerConfig, |
| 29 | + RNGConfig, |
| 30 | + SchedulerConfig, |
| 31 | + TokenizerConfig, |
| 32 | + TrainingConfig, |
| 33 | +) |
| 34 | +from megatron.bridge.training.gpt_step import forward_step |
| 35 | +from megatron.bridge.training.pretrain import pretrain |
| 36 | + |
| 37 | + |
class TrackingCallback(Callback):
    """Tracks event ordering and context field availability."""

    def __init__(self):
        self.events: list[str] = []
        self.context_snapshots: list[dict] = []

    def _record(self, event_name: str, context: CallbackContext) -> None:
        self.events.append(event_name)
        self.context_snapshots.append(
            {
                "event": event_name,
                "has_state": context.state is not None,
                "has_model": context.model is not None and len(context.model) > 0,
                "has_user_state": context.user_state is not None,
                "has_optimizer": context.optimizer is not None,
                "has_scheduler": context.scheduler is not None,
                "has_loss_dict": context.loss_dict is not None,
                "has_grad_norm": context.grad_norm is not None,
                "has_skipped_iter": context.skipped_iter is not None,
                "has_total_loss_dict": context.total_loss_dict is not None,
            }
        )

    def on_train_start(self, context: CallbackContext) -> None:
        self._record("on_train_start", context)

    def on_train_step_start(self, context: CallbackContext) -> None:
        self._record("on_train_step_start", context)

    def on_train_step_end(self, context: CallbackContext) -> None:
        self._record("on_train_step_end", context)

    def on_train_end(self, context: CallbackContext) -> None:
        self._record("on_train_end", context)

    def on_eval_start(self, context: CallbackContext) -> None:
        self._record("on_eval_start", context)

    def on_eval_step_start(self, context: CallbackContext) -> None:
        self._record("on_eval_step_start", context)

    def on_eval_step_end(self, context: CallbackContext) -> None:
        self._record("on_eval_step_end", context)

    def on_eval_end(self, context: CallbackContext) -> None:
        self._record("on_eval_end", context)

    def on_test_start(self, context: CallbackContext) -> None:
        self._record("on_test_start", context)

    def on_test_step_start(self, context: CallbackContext) -> None:
        self._record("on_test_step_start", context)

    def on_test_step_end(self, context: CallbackContext) -> None:
        self._record("on_test_step_end", context)

    def on_test_end(self, context: CallbackContext) -> None:
        self._record("on_test_end", context)

    def get_event_count(self, event_name: str) -> int:
        return sum(1 for e in self.events if e == event_name)

    def get_snapshots_for_event(self, event_name: str) -> list[dict]:
        return [s for s in self.context_snapshots if s["event"] == event_name]


class UserStateCallback(Callback):
    """Tests that user_state persists across callback events.

    Writes a counter at train start, increments it after every training
    step, and reads it back during eval, test, and at train end.
    """

    def __init__(self):
        self.step_values: list[int] = []
        self.eval_read_values: list[int] = []
        self.test_read_values: list[int] = []
        self.final_count: int | None = None

    def on_train_start(self, context: CallbackContext) -> None:
        context.user_state["counter"] = 0

    def on_train_step_end(self, context: CallbackContext) -> None:
        context.user_state["counter"] += 1
        self.step_values.append(context.user_state["counter"])

    def on_eval_start(self, context: CallbackContext) -> None:
        self.eval_read_values.append(context.user_state.get("counter", -1))

    def on_test_start(self, context: CallbackContext) -> None:
        self.test_read_values.append(context.user_state.get("counter", -1))

    def on_train_end(self, context: CallbackContext) -> None:
        self.final_count = context.user_state.get("counter", -1)


class TestCallbacksEndToEnd:
    """Functional tests for callbacks in the training loop."""

    @pytest.mark.run_only_on("GPU")
    def test_callbacks(self):
        """Comprehensive test of the callback system with both registration patterns.

        Tests in a single training run:
        1. Class-based callbacks (TrackingCallback, UserStateCallback)
        2. Functional callbacks (via register())
        3. Event firing counts and ordering
        4. Context field availability at each event
        5. user_state persistence across callback invocations
        """

        # Training configuration
        # eval_interval doesn't evenly divide train_iters, so no eval runs at the
        # last step: in-training eval fires only once (at step 5), not at step 8.
        train_iters = 8
        eval_interval = 5  # Eval only at step 5 during training
        eval_iters = 2
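        # Worked check (mirrors the assertions below): in-training evals fire
        # at multiples of eval_interval, so train_iters // eval_interval
        # (8 // 5 == 1) gives the number of in-training eval runs.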

        model_cfg = Llama32ModelProvider1B(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
            context_parallel_size=1,
            sequence_parallel=False,
            attention_softmax_in_fp32=True,
            pipeline_dtype=torch.bfloat16,
            bf16=True,
            seq_length=512,
            make_vocab_size_divisible_by=128,
            vocab_size=None,
            num_layers=1,
        )

        cfg = ConfigContainer(
            model=model_cfg,
            train=TrainingConfig(
                train_iters=train_iters,
                eval_interval=eval_interval,
                eval_iters=eval_iters,
                global_batch_size=8,
                micro_batch_size=1,
                exit_signal_handler=True,
            ),
            optimizer=OptimizerConfig(
                optimizer="adam",
                bf16=True,
                fp16=False,
                adam_beta1=0.9,
                adam_beta2=0.95,
                adam_eps=1e-5,
                use_distributed_optimizer=True,
                clip_grad=1.0,
                lr=3e-3,
                weight_decay=0.01,
                min_lr=1e-6,
            ),
            scheduler=SchedulerConfig(
                start_weight_decay=0.033,
                end_weight_decay=0.033,
                weight_decay_incr_style="constant",
                lr_decay_style="cosine",
                lr_warmup_iters=2,
                lr_warmup_init=0.0,
                lr_decay_iters=train_iters,
                override_opt_param_scheduler=True,
            ),
            ddp=DistributedDataParallelConfig(
                check_for_nan_in_grad=True,
                grad_reduce_in_fp32=True,
                overlap_grad_reduce=True,
                overlap_param_gather=True,
                average_in_collective=True,
                use_distributed_optimizer=True,
            ),
            dataset=MockGPTDatasetConfig(
                random_seed=1234,
                reset_attention_mask=False,
                reset_position_ids=False,
                eod_mask_loss=False,
                seq_length=512,
                num_dataset_builder_threads=1,
                data_sharding=True,
                dataloader_type="single",
                num_workers=1,
            ),
            logger=LoggerConfig(log_interval=5),
            tokenizer=TokenizerConfig(
                tokenizer_type="NullTokenizer",
                vocab_size=10000,
            ),
            checkpoint=CheckpointConfig(save=None),
            rng=RNGConfig(seed=1234),
        )

        # Create callbacks
        tracking_callback = TrackingCallback()
        user_state_callback = UserStateCallback()

        # Track functional callback invocations
        functional_log: list[str] = []

        # Create manager with both class-based and functional callbacks
        manager = CallbackManager()
        manager.add([tracking_callback, user_state_callback])
        manager.register("on_train_start", lambda ctx: functional_log.append("fn_start"))
        manager.register("on_train_step_end", lambda ctx: functional_log.append("fn_step"))
        manager.register("on_train_end", lambda ctx: functional_log.append("fn_end"))
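        # The registered lambdas receive the same CallbackContext as the
        # class-based hooks; here they only record that each event fired.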

        # Run training
        pretrain(cfg, forward_step, callbacks=manager)

        # Verify event firing counts
        assert tracking_callback.get_event_count("on_train_start") == 1
        assert tracking_callback.get_event_count("on_train_end") == 1
        assert tracking_callback.get_event_count("on_train_step_start") == train_iters
        assert tracking_callback.get_event_count("on_train_step_end") == train_iters

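        # NOTE: the expected counts below assume pretrain() runs one
        # post-training validation pass (on_eval_* events) and one
        # post-training test pass (on_test_* events), each of eval_iters steps.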
        # Eval runs: 1 during training (step 5 only) + 1 post-training validation
        in_training_eval_runs = train_iters // eval_interval  # 8 // 5 = 1
        post_training_eval_runs = 1  # validation only (test uses on_test_* events)
        expected_eval_runs = in_training_eval_runs + post_training_eval_runs
        assert tracking_callback.get_event_count("on_eval_start") == expected_eval_runs
        assert tracking_callback.get_event_count("on_eval_end") == expected_eval_runs

        expected_eval_steps = expected_eval_runs * eval_iters
        assert tracking_callback.get_event_count("on_eval_step_start") == expected_eval_steps
        assert tracking_callback.get_event_count("on_eval_step_end") == expected_eval_steps

        # Test runs: 1 post-training test
        expected_test_runs = 1
        assert tracking_callback.get_event_count("on_test_start") == expected_test_runs
        assert tracking_callback.get_event_count("on_test_end") == expected_test_runs

        expected_test_steps = expected_test_runs * eval_iters
        assert tracking_callback.get_event_count("on_test_step_start") == expected_test_steps
        assert tracking_callback.get_event_count("on_test_step_end") == expected_test_steps

        # Verify event order
        events = tracking_callback.events
        assert events[0] == "on_train_start", "First event should be on_train_start"
        # Post-training test is the final phase, so on_test_end is last
        assert events[-1] == "on_test_end", "Last event should be on_test_end"
        # on_train_end should come before post-training test
        train_end_idx = events.index("on_train_end")
        test_start_idx = events.index("on_test_start")
        assert train_end_idx < test_start_idx, "on_train_end should precede post-training test"

        # Verify step events come in pairs (step_end before next step_start)
        for i, event in enumerate(events):
            if event == "on_train_step_start":
                remaining = events[i + 1 :]
                next_step_start = (
                    remaining.index("on_train_step_start") if "on_train_step_start" in remaining else len(remaining)
                )
                next_step_end = (
                    remaining.index("on_train_step_end") if "on_train_step_end" in remaining else len(remaining)
                )
                assert next_step_end < next_step_start or next_step_start == len(remaining)
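        # Invariant checked above: between consecutive on_train_step_start
        # events there must be an on_train_step_end; the final step_start is
        # exempt since it has no successor in the event list.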

        # Verify context data availability
        for snapshot in tracking_callback.context_snapshots:
            assert snapshot["has_state"], f"{snapshot['event']} missing state"
            assert snapshot["has_model"], f"{snapshot['event']} missing model"
            assert snapshot["has_user_state"], f"{snapshot['event']} missing user_state"

        training_events = ["on_train_start", "on_train_step_start", "on_train_step_end", "on_train_end"]
        for snapshot in tracking_callback.context_snapshots:
            if snapshot["event"] in training_events:
                assert snapshot["has_optimizer"], f"{snapshot['event']} missing optimizer"
                assert snapshot["has_scheduler"], f"{snapshot['event']} missing scheduler"

        for snapshot in tracking_callback.get_snapshots_for_event("on_train_step_end"):
            assert snapshot["has_loss_dict"], "on_train_step_end missing loss_dict"
            assert snapshot["has_grad_norm"], "on_train_step_end missing grad_norm"
            assert snapshot["has_skipped_iter"], "on_train_step_end missing skipped_iter"

        for snapshot in tracking_callback.get_snapshots_for_event("on_eval_end"):
            assert snapshot["has_total_loss_dict"], "on_eval_end missing total_loss_dict"

        for snapshot in tracking_callback.get_snapshots_for_event("on_test_end"):
            assert snapshot["has_total_loss_dict"], "on_test_end missing total_loss_dict"

        # Verify user_state persistence (UserStateCallback)
        assert user_state_callback.final_count == train_iters, (
            f"Final counter should be {train_iters}, got {user_state_callback.final_count}"
        )
        assert user_state_callback.step_values == list(range(1, train_iters + 1)), (
            f"Step values should be [1..{train_iters}], got {user_state_callback.step_values}"
        )
        # In-training eval happens after step 5, so the counter should read 5;
        # post-training validation reads counter=8 (the final train_iters).
        assert user_state_callback.eval_read_values[0] == eval_interval, (
            f"First eval should read counter={eval_interval}, got {user_state_callback.eval_read_values[0]}"
        )
        assert user_state_callback.eval_read_values[-1] == train_iters, (
            f"Post-training eval should read counter={train_iters}, got {user_state_callback.eval_read_values[-1]}"
        )
        assert len(user_state_callback.eval_read_values) == expected_eval_runs, (
            f"Should have {expected_eval_runs} eval reads, got {len(user_state_callback.eval_read_values)}"
        )
        # Post-training test runs after training and reads counter=8
        assert len(user_state_callback.test_read_values) == expected_test_runs, (
            f"Should have {expected_test_runs} test reads, got {len(user_state_callback.test_read_values)}"
        )
        assert user_state_callback.test_read_values[0] == train_iters, (
            f"Test should read counter={train_iters}, got {user_state_callback.test_read_values[0]}"
        )

        # Verify functional callbacks fired
        assert functional_log[0] == "fn_start", "Functional on_train_start should fire"
        assert functional_log[-1] == "fn_end", "Functional on_train_end should fire"
        assert functional_log.count("fn_step") == train_iters, (
            f"Functional on_train_step_end should fire {train_iters} times"
        )