2525
2626from ci_tests import SyntheticXPointDataset , test_checkpoint_functionality
2727
28+ # Import benchmark module
29+ from benchmark import TrainingBenchmark
30+
2831def expand_xpoints_mask (binary_mask , kernel_size = 9 ):
2932 """
3033 Expands each X-point in a binary mask to include surrounding cells
@@ -481,11 +484,17 @@ def forward(self, inputs, targets):
481484 return 1.0 - dice
482485
483486# TRAIN & VALIDATION UTILS
484- def train_one_epoch (model , loader , criterion , optimizer , device , scaler , use_amp , amp_dtype ):
487+ def train_one_epoch (model , loader , criterion , optimizer , device , scaler , use_amp , amp_dtype , benchmark = None ):
485488 model .train ()
486489 running_loss = 0.0
487490
491+ # Start epoch timing for benchmark
492+ if benchmark :
493+ benchmark .start_epoch ()
494+
488495 for batch in loader :
496+ batch_start = timer ()
497+
489498 all_data , mask = batch ["all" ].to (device ), batch ["mask" ].to (device )
490499
491500 with autocast (device_type = 'cuda' , dtype = amp_dtype , enabled = use_amp ):
@@ -514,6 +523,15 @@ def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp
514523 optimizer .step ()
515524
516525 running_loss += loss .item ()
526+
527+ # Record batch timing for benchmark
528+ if benchmark :
529+ batch_time = timer () - batch_start
530+ benchmark .record_batch (all_data .size (0 ), batch_time )
531+
532+ # End epoch timing for benchmark
533+ if benchmark :
534+ benchmark .end_epoch ()
517535
518536 return running_loss / len (loader ) if len (loader ) > 0 else 0.0
519537
@@ -672,20 +690,20 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi
672690
673691 plt .close ()
674692
675- def plot_training_history (train_losses , val_losses , save_path = 'plots/training_history.png' ):
693+ def plot_training_history (train_losses , val_loss , save_path = 'plots/training_history.png' ):
676694 """
677695 Plots training and validation losses across epochs.
678696
679697 Parameters:
680698 train_losses (list): List of training losses for each epoch
681- val_losses (list): List of validation losses for each epoch
699+ val_loss (list): List of validation losses for each epoch
682700 save_path (str): Path to save the resulting plot
683701 """
684702 plt .figure (figsize = (10 , 6 ))
685703 epochs = range (1 , len (train_losses ) + 1 )
686704
687705 plt .plot (epochs , train_losses , 'b-' , label = 'Training Loss' )
688- plt .plot (epochs , val_losses , 'r-' , label = 'Validation Loss' )
706+ plt .plot (epochs , val_loss , 'r-' , label = 'Validation Loss' )
689707
690708 plt .title ('Training and Validation Loss' )
691709 plt .xlabel ('Epochs' )
@@ -695,8 +713,8 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi
695713 plt .grid (True , linestyle = '--' , alpha = 0.7 )
696714
697715 # Add some padding to y-axis to make visualization clearer
698- ymin = min (min (train_losses ), min (val_losses )) * 0.9
699- ymax = max (max (train_losses ), max (val_losses )) * 1.1
716+ ymin = min (min (train_losses ), min (val_loss )) * 0.9
717+ ymax = max (max (train_losses ), max (val_loss )) * 1.1
700718 plt .ylim (ymin , ymax )
701719
702720 plt .savefig (save_path , dpi = 300 )
@@ -746,6 +764,10 @@ def parseCommandLineArgs():
746764 choices = ['float16' , 'bfloat16' ], help = 'data type for mixed precision (bfloat16 recommended)' )
747765 parser .add_argument ('--patience' , type = int , default = 15 ,
748766 help = 'patience for early stopping (default: 15)' )
767+ parser .add_argument ('--benchmark' , action = 'store_true' ,
768+ help = 'enable performance benchmarking (tracks timing, throughput, GPU memory)' )
769+ parser .add_argument ('--benchmark-output' , type = Path , default = './benchmark_results.json' ,
770+ help = 'path to save benchmark results JSON file (default: ./benchmark_results.json)' )
749771
750772 # CI TEST: Add smoke test flag
751773 parser .add_argument ('--smoke-test' , action = 'store_true' ,
@@ -974,6 +996,11 @@ def main():
974996 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
975997 print (f"Using device: { device } " )
976998
999+ # Initialize benchmark tracker
1000+ benchmark = TrainingBenchmark (device , enabled = args .benchmark )
1001+ if args .benchmark :
1002+ benchmark .print_hardware_info ()
1003+
9771004 # Use the improved model
9781005 model = UNet (input_channels = 4 , base_channels = 32 ).to (device )
9791006
@@ -1028,14 +1055,21 @@ def main():
10281055
10291056 num_epochs = args .epochs
10301057 for epoch in range (start_epoch , num_epochs ):
1031- train_loss_epoch = train_one_epoch (model , train_loader , criterion , optimizer , device , scaler , use_amp , amp_dtype )
1058+ train_loss_epoch = train_one_epoch (model , train_loader , criterion , optimizer , device , scaler , use_amp , amp_dtype , benchmark )
10321059 val_loss_epoch = validate_one_epoch (model , val_loader , criterion , device , use_amp , amp_dtype )
10331060
10341061 train_loss .append (train_loss_epoch )
10351062 val_loss .append (val_loss_epoch )
10361063
10371064 current_lr = optimizer .param_groups [0 ]['lr' ]
1038- print (f"[Epoch { epoch + 1 } /{ num_epochs } ] LR={ current_lr :.2e} TrainLoss={ train_loss [- 1 ]:.6f} ValLoss={ val_loss [- 1 ]:.6f} " )
1065+
1066+ # Enhanced logging with benchmark metrics
1067+ log_msg = f"[Epoch { epoch + 1 } /{ num_epochs } ] LR={ current_lr :.2e} TrainLoss={ train_loss [- 1 ]:.6f} ValLoss={ val_loss [- 1 ]:.6f} "
1068+ if args .benchmark :
1069+ throughput = benchmark .get_throughput ()
1070+ gpu_mem = benchmark .get_gpu_memory_usage ()
1071+ log_msg += f" | Throughput={ throughput :.2f} samples/s | GPU Mem={ gpu_mem :.2f} GB"
1072+ print (log_msg )
10391073
10401074 # Learning rate scheduling
10411075 scheduler .step ()
@@ -1059,8 +1093,12 @@ def main():
10591093 print (f"Early stopping triggered after { epoch + 1 } epochs (patience={ args .patience } )" )
10601094 break
10611095
1062- plot_training_history (train_loss , val_loss )
1096+ plot_training_history (train_loss , val_loss , save_path = 'plots/training_history.png' )
10631097 print ("time (s) to train model: " + str (timer ()- t2 ))
1098+
1099+ # Print and save benchmark summary
1100+ if args .benchmark :
1101+ benchmark .print_summary (output_file = args .benchmark_output )
10641102
10651103 # CI TEST: Run additional tests if in smoke test mode
10661104 if args .smoke_test :
@@ -1139,7 +1177,7 @@ def main():
11391177 outDir = "plots"
11401178 interpFac = 1
11411179
1142- # Evaluate on combined set for demonstration. Exam this part to see if save to remove
1180+ # Evaluate on combined set for demonstration
11431181 if not args .smoke_test :
11441182 train_fnums = range (args .trainFrameFirst , args .trainFrameLast )
11451183 val_fnums = range (args .validationFrameFirst , args .validationFrameLast )
@@ -1175,10 +1213,6 @@ def main():
11751213
11761214 pred_mask_bin = (pred_prob_np [0 ,0 ] > 0.5 ).astype (np .float32 )
11771215
1178- print (f"Frame { fnum } rotation { rotation } reflectionAxis { reflectionAxis } :" )
1179- print (f" Probabilities - min: { pred_prob_np .min ():.5f} , max: { pred_prob_np .max ():.5f} , mean: { pred_prob_np .mean ():.5f} " )
1180- print (f" Binary Mask (X-points) - count of 1s: { np .sum (pred_mask_bin )} / { pred_mask_bin .size } pixels" )
1181-
11821216 if args .plot :
11831217 # Plot GROUND TRUTH
11841218 plot_psi_contours_and_xpoints (
0 commit comments