Commit 2c99674

add testing
1 parent f4ee1c7 commit 2c99674

1 file changed: +113 -0 lines

tests/tests_pytorch/callbacks/test_throughput_monitor.py

Lines changed: 113 additions & 0 deletions
@@ -307,3 +307,116 @@ def test_throughput_monitor_eval(tmp_path, fn):
         call(metrics={**expected, f"{fn}|batches": 9, f"{fn}|samples": 27}, step=9),
         call(metrics={**expected, f"{fn}|batches": 12, f"{fn}|samples": 36}, step=12),
     ]
+
+
+def test_throughput_monitor_variable_batch_size(tmp_path):
+    """Test that ThroughputMonitor correctly handles variable batch sizes."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    # Simulate variable batch sizes by tracking calls
+    batch_sizes = [1, 3, 2, 1, 4]
+    call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        # Return the predefined batch size for this call
+        current_batch_size = batch_sizes[call_count[0] % len(batch_sizes)]
+        call_count[0] += 1
+        return current_batch_size
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=5, separator="|")
+
+    model = BoringModel()
+    model.flops_per_batch = 10
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=0,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    timings = [0.0] + [i * 0.1 for i in range(1, len(batch_sizes) + 1)]
+
+    with (
+        mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
+        mock.patch("time.perf_counter", side_effect=timings),
+    ):
+        trainer.fit(model)
+
+    log_calls = logger_mock.log_metrics.call_args_list
+    assert len(log_calls) == len(batch_sizes)
+
+    # Expected cumulative samples: 1, 4 (1+3), 6 (4+2), 7 (6+1), 11 (7+4)
+    expected_cumulative_samples = [1, 4, 6, 7, 11]
+
+    for i, log_call in enumerate(log_calls):
+        metrics = log_call.kwargs["metrics"] if "metrics" in log_call.kwargs else log_call.args[0]
+        expected_samples = expected_cumulative_samples[i]
+        assert metrics["train|samples"] == expected_samples, (
+            f"Step {i}: expected {expected_samples}, got {metrics['train|samples']}"
+        )
+        assert metrics["train|batches"] == i + 1, f"Step {i}: expected batches {i + 1}, got {metrics['train|batches']}"
+
+
+def test_throughput_monitor_variable_batch_size_with_validation(tmp_path):
+    """Test variable batch sizes with validation to ensure stage isolation."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    train_batch_sizes = [2, 1, 3]
+    val_batch_sizes = [1, 2]
+    train_call_count = [0]
+    val_call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        if hasattr(batch, "size") and batch.size(0) > 0:
+            if train_call_count[0] < len(train_batch_sizes):
+                current_batch_size = train_batch_sizes[train_call_count[0]]
+                train_call_count[0] += 1
+                return current_batch_size
+            current_batch_size = val_batch_sizes[val_call_count[0] % len(val_batch_sizes)]
+            val_call_count[0] += 1
+            return current_batch_size
+        return 1
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=3)
+    model = BoringModel()
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(train_batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=2,
+        val_check_interval=2,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100):
+        trainer.fit(model)
+
+    # Verify that both training and validation metrics were logged
+    log_calls = logger_mock.log_metrics.call_args_list
+    train_calls = [call for call in log_calls if "train/" in str(call) or "train|" in str(call)]
+    val_calls = [call for call in log_calls if "validate/" in str(call) or "validate|" in str(call)]
+
+    assert len(train_calls) > 0, "Expected training metrics to be logged"
+    assert len(val_calls) > 0, "Expected validation metrics to be logged"
+    train_samples = []
+    for train_call in train_calls:
+        metrics = train_call.kwargs.get("metrics", train_call.args[0] if train_call.args else {})
+        if "train/samples" in metrics:
+            train_samples.append(metrics["train/samples"])
+        elif "train|samples" in metrics:
+            train_samples.append(metrics["train|samples"])
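
For context on what these tests exercise: the callback calls batch_size_fn on every batch and accumulates the returned sizes, which is why the tests above drive it with predefined size lists. Below is a minimal usage sketch of wiring a variable batch_size_fn in user code; the lambda, the tensor-batch layout, and the Trainer arguments are illustrative assumptions, not part of this commit.

# Minimal sketch (assumptions noted above): return the leading tensor
# dimension so the monitor counts samples correctly even when sizes vary.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ThroughputMonitor

monitor = ThroughputMonitor(batch_size_fn=lambda batch: batch.size(0))
trainer = Trainer(max_epochs=1, callbacks=[monitor])
# trainer.fit(model)  # `model` would be any LightningModule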
