@@ -307,3 +307,116 @@ def test_throughput_monitor_eval(tmp_path, fn):
         call(metrics={**expected, f"{fn}|batches": 9, f"{fn}|samples": 27}, step=9),
         call(metrics={**expected, f"{fn}|batches": 12, f"{fn}|samples": 36}, step=12),
     ]
+
+
+def test_throughput_monitor_variable_batch_size(tmp_path):
+    """Test that ThroughputMonitor correctly handles variable batch sizes."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    # Simulate variable batch sizes by tracking calls
+    batch_sizes = [1, 3, 2, 1, 4]
+    call_count = [0]
+
+    def variable_batch_size_fn(batch):
+        # Return the predefined batch size for this call
+        current_batch_size = batch_sizes[call_count[0] % len(batch_sizes)]
+        call_count[0] += 1
+        return current_batch_size
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=5, separator="|")
+
+    model = BoringModel()
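+    # the monitor reads `flops_per_batch` from the module to derive FLOP-based throughput metrics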
+    model.flops_per_batch = 10
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=0,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
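+    # one reading for the start-of-training timestamp plus one per batch end (assumed perf_counter usage)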
+    timings = [0.0] + [i * 0.1 for i in range(1, len(batch_sizes) + 1)]
+
+    with (
+        mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100),
+        mock.patch("time.perf_counter", side_effect=timings),
+    ):
+        trainer.fit(model)
+
+    log_calls = logger_mock.log_metrics.call_args_list
+    assert len(log_calls) == len(batch_sizes)
+
+    # Expected cumulative samples: 1, 4 (1+3), 6 (4+2), 7 (6+1), 11 (7+4)
+    expected_cumulative_samples = [1, 4, 6, 7, 11]
+
+    for i, log_call in enumerate(log_calls):
+        metrics = log_call.kwargs["metrics"] if "metrics" in log_call.kwargs else log_call.args[0]
+        expected_samples = expected_cumulative_samples[i]
+        assert metrics["train|samples"] == expected_samples, (
+            f"Step {i}: expected {expected_samples}, got {metrics['train|samples']}"
+        )
+        assert metrics["train|batches"] == i + 1, (
+            f"Step {i}: expected batches {i + 1}, got {metrics['train|batches']}"
+        )
+
+
+def test_throughput_monitor_variable_batch_size_with_validation(tmp_path):
+    """Test variable batch sizes with validation to ensure stage isolation."""
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+
+    train_batch_sizes = [2, 1, 3]
+    val_batch_sizes = [1, 2]
+    train_call_count = [0]
+    val_call_count = [0]
+
+    def variable_batch_size_fn(batch):
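+        # Heuristic for this test: the first len(train_batch_sizes) calls are treated as training batches,
+        # any later tensor batches as validation batches (relies on call order rather than batch contents)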
+        if hasattr(batch, "size") and batch.size(0) > 0:
+            if train_call_count[0] < len(train_batch_sizes):
+                current_batch_size = train_batch_sizes[train_call_count[0]]
+                train_call_count[0] += 1
+                return current_batch_size
+            current_batch_size = val_batch_sizes[val_call_count[0] % len(val_batch_sizes)]
+            val_call_count[0] += 1
+            return current_batch_size
+        return 1
+
+    monitor = ThroughputMonitor(batch_size_fn=variable_batch_size_fn, window_size=3)
+    model = BoringModel()
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=monitor,
+        max_steps=len(train_batch_sizes),
+        log_every_n_steps=1,
+        limit_val_batches=2,
+        val_check_interval=2,
+        num_sanity_val_steps=0,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100):
+        trainer.fit(model)
+
+    # Verify that both training and validation metrics were logged
+    log_calls = logger_mock.log_metrics.call_args_list
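+    # No separator was passed to the monitor, so keys use the default "/" form; the filters below accept
+    # either "/" or "|" so the assertions do not depend on that choice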
+    train_calls = [call for call in log_calls if "train/" in str(call) or "train|" in str(call)]
+    val_calls = [call for call in log_calls if "validate/" in str(call) or "validate|" in str(call)]
+
+    assert len(train_calls) > 0, "Expected training metrics to be logged"
+    assert len(val_calls) > 0, "Expected validation metrics to be logged"
+    train_samples = []
+    for train_call in train_calls:
+        metrics = train_call.kwargs.get("metrics", train_call.args[0] if train_call.args else {})
+        if "train/samples" in metrics:
+            train_samples.append(metrics["train/samples"])
+        elif "train|samples" in metrics:
+            train_samples.append(metrics["train|samples"])