Skip to content

Commit 93378a7

Browse files
committed
Add performance profiling for training runs
1 parent f99434f commit 93378a7

File tree

2 files changed

+336
-14
lines changed

2 files changed

+336
-14
lines changed

XPointMLTest.py

Lines changed: 48 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,9 @@
2525

2626
from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality
2727

28+
# Import benchmark module
29+
from benchmark import TrainingBenchmark
30+
2831
def expand_xpoints_mask(binary_mask, kernel_size=9):
2932
"""
3033
Expands each X-point in a binary mask to include surrounding cells
@@ -481,11 +484,17 @@ def forward(self, inputs, targets):
481484
return 1.0 - dice
482485

483486
# TRAIN & VALIDATION UTILS
484-
def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype):
487+
def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark=None):
485488
model.train()
486489
running_loss = 0.0
487490

491+
# Start epoch timing for benchmark
492+
if benchmark:
493+
benchmark.start_epoch()
494+
488495
for batch in loader:
496+
batch_start = timer()
497+
489498
all_data, mask = batch["all"].to(device), batch["mask"].to(device)
490499

491500
with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
@@ -514,6 +523,15 @@ def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp
514523
optimizer.step()
515524

516525
running_loss += loss.item()
526+
527+
# Record batch timing for benchmark
528+
if benchmark:
529+
batch_time = timer() - batch_start
530+
benchmark.record_batch(all_data.size(0), batch_time)
531+
532+
# End epoch timing for benchmark
533+
if benchmark:
534+
benchmark.end_epoch()
517535

518536
return running_loss / len(loader) if len(loader) > 0 else 0.0
519537

@@ -672,20 +690,20 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi
672690

673691
plt.close()
674692

675-
def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'):
693+
def plot_training_history(train_losses, val_loss, save_path='plots/training_history.png'):
676694
"""
677695
Plots training and validation losses across epochs.
678696
679697
Parameters:
680698
train_losses (list): List of training losses for each epoch
681-
val_losses (list): List of validation losses for each epoch
699+
val_loss (list): List of validation losses for each epoch
682700
save_path (str): Path to save the resulting plot
683701
"""
684702
plt.figure(figsize=(10, 6))
685703
epochs = range(1, len(train_losses) + 1)
686704

687705
plt.plot(epochs, train_losses, 'b-', label='Training Loss')
688-
plt.plot(epochs, val_losses, 'r-', label='Validation Loss')
706+
plt.plot(epochs, val_loss, 'r-', label='Validation Loss')
689707

690708
plt.title('Training and Validation Loss')
691709
plt.xlabel('Epochs')
@@ -695,8 +713,8 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi
695713
plt.grid(True, linestyle='--', alpha=0.7)
696714

697715
# Add some padding to y-axis to make visualization clearer
698-
ymin = min(min(train_losses), min(val_losses)) * 0.9
699-
ymax = max(max(train_losses), max(val_losses)) * 1.1
716+
ymin = min(min(train_losses), min(val_loss)) * 0.9
717+
ymax = max(max(train_losses), max(val_loss)) * 1.1
700718
plt.ylim(ymin, ymax)
701719

702720
plt.savefig(save_path, dpi=300)
@@ -746,6 +764,10 @@ def parseCommandLineArgs():
746764
choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)')
747765
parser.add_argument('--patience', type=int, default=15,
748766
help='patience for early stopping (default: 15)')
767+
parser.add_argument('--benchmark', action='store_true',
768+
help='enable performance benchmarking (tracks timing, throughput, GPU memory)')
769+
parser.add_argument('--benchmark-output', type=Path, default='./benchmark_results.json',
770+
help='path to save benchmark results JSON file (default: ./benchmark_results.json)')
749771

750772
# CI TEST: Add smoke test flag
751773
parser.add_argument('--smoke-test', action='store_true',
@@ -974,6 +996,11 @@ def main():
974996
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
975997
print(f"Using device: {device}")
976998

999+
# Initialize benchmark tracker
1000+
benchmark = TrainingBenchmark(device, enabled=args.benchmark)
1001+
if args.benchmark:
1002+
benchmark.print_hardware_info()
1003+
9771004
# Use the improved model
9781005
model = UNet(input_channels=4, base_channels=32).to(device)
9791006

@@ -1028,14 +1055,21 @@ def main():
10281055

10291056
num_epochs = args.epochs
10301057
for epoch in range(start_epoch, num_epochs):
1031-
train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype)
1058+
train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark)
10321059
val_loss_epoch = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype)
10331060

10341061
train_loss.append(train_loss_epoch)
10351062
val_loss.append(val_loss_epoch)
10361063

10371064
current_lr = optimizer.param_groups[0]['lr']
1038-
print(f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}")
1065+
1066+
# Enhanced logging with benchmark metrics
1067+
log_msg = f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}"
1068+
if args.benchmark:
1069+
throughput = benchmark.get_throughput()
1070+
gpu_mem = benchmark.get_gpu_memory_usage()
1071+
log_msg += f" | Throughput={throughput:.2f} samples/s | GPU Mem={gpu_mem:.2f} GB"
1072+
print(log_msg)
10391073

10401074
# Learning rate scheduling
10411075
scheduler.step()
@@ -1059,8 +1093,12 @@ def main():
10591093
print(f"Early stopping triggered after {epoch+1} epochs (patience={args.patience})")
10601094
break
10611095

1062-
plot_training_history(train_loss, val_loss)
1096+
plot_training_history(train_loss, val_loss, save_path='plots/training_history.png')
10631097
print("time (s) to train model: " + str(timer()-t2))
1098+
1099+
# Print and save benchmark summary
1100+
if args.benchmark:
1101+
benchmark.print_summary(output_file=args.benchmark_output)
10641102

10651103
# CI TEST: Run additional tests if in smoke test mode
10661104
if args.smoke_test:
@@ -1139,7 +1177,7 @@ def main():
11391177
outDir = "plots"
11401178
interpFac = 1
11411179

1142-
# Evaluate on combined set for demonstration. Examine this part to see if it is safe to remove
1180+
# Evaluate on combined set for demonstration
11431181
if not args.smoke_test:
11441182
train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
11451183
val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
@@ -1175,10 +1213,6 @@ def main():
11751213

11761214
pred_mask_bin = (pred_prob_np[0,0] > 0.5).astype(np.float32)
11771215

1178-
print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:")
1179-
print(f" Probabilities - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}")
1180-
print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels")
1181-
11821216
if args.plot:
11831217
# Plot GROUND TRUTH
11841218
plot_psi_contours_and_xpoints(

0 commit comments

Comments
 (0)