Commit 7e31023

add tests
Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent 726dfbf commit 7e31023

2 files changed: +920 -0 lines changed

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functional tests for the training callback system."""

import pytest
import torch

from megatron.bridge.models.llama import Llama32ModelProvider1B
from megatron.bridge.training.callbacks import Callback, CallbackContext, CallbackManager
from megatron.bridge.training.config import (
    CheckpointConfig,
    ConfigContainer,
    DistributedDataParallelConfig,
    LoggerConfig,
    MockGPTDatasetConfig,
    OptimizerConfig,
    RNGConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainingConfig,
)
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain


class TrackingCallback(Callback):
    """Tracks event ordering and context field availability."""

    def __init__(self):
        self.events: list[str] = []
        self.context_snapshots: list[dict] = []

    def _record(self, event_name: str, context: CallbackContext) -> None:
        self.events.append(event_name)
        self.context_snapshots.append(
            {
                "event": event_name,
                "has_state": context.state is not None,
                "has_model": context.model is not None and len(context.model) > 0,
                "has_user_state": context.user_state is not None,
                "has_optimizer": context.optimizer is not None,
                "has_scheduler": context.scheduler is not None,
                "has_loss_dict": context.loss_dict is not None,
                "has_grad_norm": context.grad_norm is not None,
                "has_skipped_iter": context.skipped_iter is not None,
                "has_total_loss_dict": context.total_loss_dict is not None,
            }
        )

    def on_train_start(self, context: CallbackContext) -> None:
        self._record("on_train_start", context)

    def on_train_step_start(self, context: CallbackContext) -> None:
        self._record("on_train_step_start", context)

    def on_train_step_end(self, context: CallbackContext) -> None:
        self._record("on_train_step_end", context)

    def on_train_end(self, context: CallbackContext) -> None:
        self._record("on_train_end", context)

    def on_eval_start(self, context: CallbackContext) -> None:
        self._record("on_eval_start", context)

    def on_eval_step_start(self, context: CallbackContext) -> None:
        self._record("on_eval_step_start", context)

    def on_eval_step_end(self, context: CallbackContext) -> None:
        self._record("on_eval_step_end", context)

    def on_eval_end(self, context: CallbackContext) -> None:
        self._record("on_eval_end", context)

    def on_test_start(self, context: CallbackContext) -> None:
        self._record("on_test_start", context)

    def on_test_step_start(self, context: CallbackContext) -> None:
        self._record("on_test_step_start", context)

    def on_test_step_end(self, context: CallbackContext) -> None:
        self._record("on_test_step_end", context)

    def on_test_end(self, context: CallbackContext) -> None:
        self._record("on_test_end", context)

    def get_event_count(self, event_name: str) -> int:
        return sum(1 for e in self.events if e == event_name)

    def get_snapshots_for_event(self, event_name: str) -> list[dict]:
        return [s for s in self.context_snapshots if s["event"] == event_name]


class UserStateCallback(Callback):
    """Tests user_state persistence across events."""

    def __init__(self):
        self.step_values: list[int] = []
        self.eval_read_values: list[int] = []
        self.test_read_values: list[int] = []
        self.final_count: int | None = None

    def on_train_start(self, context: CallbackContext) -> None:
        context.user_state["counter"] = 0

    def on_train_step_end(self, context: CallbackContext) -> None:
        context.user_state["counter"] += 1
        self.step_values.append(context.user_state["counter"])

    def on_eval_start(self, context: CallbackContext) -> None:
        self.eval_read_values.append(context.user_state.get("counter", -1))

    def on_test_start(self, context: CallbackContext) -> None:
        self.test_read_values.append(context.user_state.get("counter", -1))

    def on_train_end(self, context: CallbackContext) -> None:
        self.final_count = context.user_state.get("counter", -1)


class TestCallbacksEndToEnd:
    """Functional tests for callbacks in the training loop."""

    @pytest.mark.run_only_on("GPU")
    def test_callbacks(self):
        """Comprehensive test of the callback system with both registration patterns.

        Tests in a single training run:
        1. Class-based callbacks (TrackingCallback, UserStateCallback)
        2. Functional callbacks (via register())
        3. Event firing counts and ordering
        4. Context field availability at each event
        5. user_state persistence across callback invocations
        """

        # Training configuration:
        # eval_interval doesn't evenly divide train_iters, so no eval falls on the
        # last step. This ensures in-training eval runs only once (at step 5, not step 8).
        train_iters = 8
        eval_interval = 5  # Eval only at step 5 during training
        eval_iters = 2

        model_cfg = Llama32ModelProvider1B(
            tensor_model_parallel_size=1,
            pipeline_model_parallel_size=1,
            context_parallel_size=1,
            sequence_parallel=False,
            attention_softmax_in_fp32=True,
            pipeline_dtype=torch.bfloat16,
            bf16=True,
            seq_length=512,
            make_vocab_size_divisible_by=128,
            vocab_size=None,
            num_layers=1,
        )

        cfg = ConfigContainer(
            model=model_cfg,
            train=TrainingConfig(
                train_iters=train_iters,
                eval_interval=eval_interval,
                eval_iters=eval_iters,
                global_batch_size=8,
                micro_batch_size=1,
                exit_signal_handler=True,
            ),
            optimizer=OptimizerConfig(
                optimizer="adam",
                bf16=True,
                fp16=False,
                adam_beta1=0.9,
                adam_beta2=0.95,
                adam_eps=1e-5,
                use_distributed_optimizer=True,
                clip_grad=1.0,
                lr=3e-3,
                weight_decay=0.01,
                min_lr=1e-6,
            ),
            scheduler=SchedulerConfig(
                start_weight_decay=0.033,
                end_weight_decay=0.033,
                weight_decay_incr_style="constant",
                lr_decay_style="cosine",
                lr_warmup_iters=2,
                lr_warmup_init=0.0,
                lr_decay_iters=train_iters,
                override_opt_param_scheduler=True,
            ),
            ddp=DistributedDataParallelConfig(
                check_for_nan_in_grad=True,
                grad_reduce_in_fp32=True,
                overlap_grad_reduce=True,
                overlap_param_gather=True,
                average_in_collective=True,
                use_distributed_optimizer=True,
            ),
            dataset=MockGPTDatasetConfig(
                random_seed=1234,
                reset_attention_mask=False,
                reset_position_ids=False,
                eod_mask_loss=False,
                seq_length=512,
                num_dataset_builder_threads=1,
                data_sharding=True,
                dataloader_type="single",
                num_workers=1,
            ),
            logger=LoggerConfig(log_interval=5),
            tokenizer=TokenizerConfig(
                tokenizer_type="NullTokenizer",
                vocab_size=10000,
            ),
            checkpoint=CheckpointConfig(save=None),
            rng=RNGConfig(seed=1234),
        )

        # Create callbacks
        tracking_callback = TrackingCallback()
        user_state_callback = UserStateCallback()

        # Track functional callback invocations
        functional_log: list[str] = []

        # Create manager with both class-based and functional callbacks
        manager = CallbackManager()
        manager.add([tracking_callback, user_state_callback])
        manager.register("on_train_start", lambda ctx: functional_log.append("fn_start"))
        manager.register("on_train_step_end", lambda ctx: functional_log.append("fn_step"))
        manager.register("on_train_end", lambda ctx: functional_log.append("fn_end"))

        # Run training
        pretrain(cfg, forward_step, callbacks=manager)

        # Verify event firing counts
        assert tracking_callback.get_event_count("on_train_start") == 1
        assert tracking_callback.get_event_count("on_train_end") == 1
        assert tracking_callback.get_event_count("on_train_step_start") == train_iters
        assert tracking_callback.get_event_count("on_train_step_end") == train_iters

        # Eval runs: 1 during training (step 5 only) + 1 post-training validation
        in_training_eval_runs = train_iters // eval_interval  # 8 // 5 = 1
        post_training_eval_runs = 1  # validation only (test uses on_test_* events)
        expected_eval_runs = in_training_eval_runs + post_training_eval_runs
        assert tracking_callback.get_event_count("on_eval_start") == expected_eval_runs
        assert tracking_callback.get_event_count("on_eval_end") == expected_eval_runs

        expected_eval_steps = expected_eval_runs * eval_iters
        assert tracking_callback.get_event_count("on_eval_step_start") == expected_eval_steps
        assert tracking_callback.get_event_count("on_eval_step_end") == expected_eval_steps

        # Test runs: 1 post-training test
        expected_test_runs = 1
        assert tracking_callback.get_event_count("on_test_start") == expected_test_runs
        assert tracking_callback.get_event_count("on_test_end") == expected_test_runs

        expected_test_steps = expected_test_runs * eval_iters
        assert tracking_callback.get_event_count("on_test_step_start") == expected_test_steps
        assert tracking_callback.get_event_count("on_test_step_end") == expected_test_steps

        # Verify event order
        events = tracking_callback.events
        assert events[0] == "on_train_start", "First event should be on_train_start"
        # Post-training test is the final phase, so on_test_end is last
        assert events[-1] == "on_test_end", "Last event should be on_test_end"
        # on_train_end should come before the post-training test
        train_end_idx = events.index("on_train_end")
        test_start_idx = events.index("on_test_start")
        assert train_end_idx < test_start_idx, "on_train_end should precede post-training test"

        # Verify step events come in pairs (step_end before the next step_start)
        for i, event in enumerate(events):
            if event == "on_train_step_start":
                remaining = events[i + 1 :]
                next_step_start = (
                    remaining.index("on_train_step_start") if "on_train_step_start" in remaining else len(remaining)
                )
                next_step_end = (
                    remaining.index("on_train_step_end") if "on_train_step_end" in remaining else len(remaining)
                )
                assert next_step_end < next_step_start or next_step_start == len(remaining)

        # Verify context data availability
        for snapshot in tracking_callback.context_snapshots:
            assert snapshot["has_state"], f"{snapshot['event']} missing state"
            assert snapshot["has_model"], f"{snapshot['event']} missing model"
            assert snapshot["has_user_state"], f"{snapshot['event']} missing user_state"

        training_events = ["on_train_start", "on_train_step_start", "on_train_step_end", "on_train_end"]
        for snapshot in tracking_callback.context_snapshots:
            if snapshot["event"] in training_events:
                assert snapshot["has_optimizer"], f"{snapshot['event']} missing optimizer"
                assert snapshot["has_scheduler"], f"{snapshot['event']} missing scheduler"

        for snapshot in tracking_callback.get_snapshots_for_event("on_train_step_end"):
            assert snapshot["has_loss_dict"], "on_train_step_end missing loss_dict"
            assert snapshot["has_grad_norm"], "on_train_step_end missing grad_norm"
            assert snapshot["has_skipped_iter"], "on_train_step_end missing skipped_iter"

        for snapshot in tracking_callback.get_snapshots_for_event("on_eval_end"):
            assert snapshot["has_total_loss_dict"], "on_eval_end missing total_loss_dict"

        for snapshot in tracking_callback.get_snapshots_for_event("on_test_end"):
            assert snapshot["has_total_loss_dict"], "on_test_end missing total_loss_dict"

        # Verify user_state persistence (UserStateCallback)
        assert user_state_callback.final_count == train_iters, (
            f"Final counter should be {train_iters}, got {user_state_callback.final_count}"
        )
        assert user_state_callback.step_values == list(range(1, train_iters + 1)), (
            f"Step values should be [1..{train_iters}], got {user_state_callback.step_values}"
        )
        # In-training eval happens after step 5, so the counter should read 5;
        # post-training validation reads counter=8 (the final train_iters).
        assert user_state_callback.eval_read_values[0] == eval_interval, (
            f"First eval should read counter={eval_interval}, got {user_state_callback.eval_read_values[0]}"
        )
        assert user_state_callback.eval_read_values[-1] == train_iters, (
            f"Post-training eval should read counter={train_iters}, got {user_state_callback.eval_read_values[-1]}"
        )
        assert len(user_state_callback.eval_read_values) == expected_eval_runs, (
            f"Should have {expected_eval_runs} eval reads, got {len(user_state_callback.eval_read_values)}"
        )
        # The post-training test runs after training and reads counter=8
        assert len(user_state_callback.test_read_values) == expected_test_runs, (
            f"Should have {expected_test_runs} test reads, got {len(user_state_callback.test_read_values)}"
        )
        assert user_state_callback.test_read_values[0] == train_iters, (
            f"Test should read counter={train_iters}, got {user_state_callback.test_read_values[0]}"
        )

        # Verify functional callbacks fired
        assert functional_log[0] == "fn_start", "Functional on_train_start should fire"
        assert functional_log[-1] == "fn_end", "Functional on_train_end should fire"
        assert functional_log.count("fn_step") == train_iters, (
            f"Functional on_train_step_end should fire {train_iters} times"
        )
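
For reference, a minimal sketch of the two registration patterns this test exercises, using only the CallbackManager, Callback, and CallbackContext APIs that appear in the diff above. The LossLogger class and the print statements are illustrative, not part of this commit:

from megatron.bridge.training.callbacks import Callback, CallbackContext, CallbackManager

# Class-based pattern: subclass Callback and override only the hooks you need.
class LossLogger(Callback):  # hypothetical example callback
    def on_train_step_end(self, context: CallbackContext) -> None:
        # loss_dict is populated by the time a train step ends,
        # per the context-availability assertions in the test above.
        print(context.loss_dict)

manager = CallbackManager()
manager.add([LossLogger()])

# Functional pattern: register a plain callable for a single event.
manager.register("on_train_start", lambda ctx: print("training started"))

# The manager is then handed to the training entry point, as in the test:
# pretrain(cfg, forward_step, callbacks=manager)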
