
Commit 3635d7b

Merge pull request #17 from SCOREC/performance-profiling
Benchmark and Model Evaluation Metrics
2 parents 68bd475 + 2c0277a commit 3635d7b

4 files changed: +719 additions, -15 deletions


README.md

Lines changed: 53 additions & 1 deletion
@@ -108,6 +108,11 @@ The classifier supports several command line options for training configuration:
 - `--plotDir`: Directory where figures are written (default: `./plots`)
 - `--checkPointFrequency`: Number of epochs between model checkpoints (default: 10)
 
+### Performance Benchmarking
+- `--benchmark`: Enable performance benchmarking (tracks timing, throughput, GPU memory)
+- `--benchmark-output`: Path to save benchmark results JSON file (default: `./benchmark_results.json`)
+- `--eval-output`: Path to save evaluation metrics JSON file (default: `./evaluation_metrics.json`)
+
 ### Testing
 - `--smoke-test`: Run minimal smoke test for CI (overrides other parameters for quick validation)

@@ -138,4 +143,51 @@ The following commands should be run on `checkers` **every time you create a new
 cd nsfCssiMlClassifier
 source envPyTorch.sh
 source pgkyl/bin/activate
-```
+```
+
+## Model Evaluation Metrics
+
+The model evaluation system measures how well the classifier identifies X-points (magnetic reconnection sites) by treating detection as a pixel-level binary classification problem.
+
+### Key Metrics
+
+The evaluation outputs several metrics saved to JSON files:
+
+- **Accuracy**: Overall pixel classification correctness (can be misleading due to class imbalance)
+- **Precision**: Fraction of detected X-points that are correct (measures false alarm rate)
+- **Recall**: Fraction of actual X-points that were found (measures miss rate)
+- **F1 Score**: Harmonic mean of precision and recall (balanced performance metric)
+- **IoU**: Intersection over Union - spatial overlap quality between predicted and actual X-point regions
+
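These pixel-level quantities can be computed directly from a predicted and a ground-truth binary mask, as in the NumPy sketch below. This is illustrative only: the commit's actual implementation lives in the new `eval_metrics` module, which is not included in this diff, and `pixel_metrics` is a hypothetical helper name.

```python
import numpy as np

def pixel_metrics(pred_mask, gt_mask, eps=1e-8):
    """Pixel-level binary classification metrics from two 0/1 masks (hypothetical helper)."""
    pred = pred_mask.astype(bool)
    gt = gt_mask.astype(bool)

    tp = np.logical_and(pred, gt).sum()    # correctly detected X-point pixels
    fp = np.logical_and(pred, ~gt).sum()   # false alarms
    fn = np.logical_and(~pred, gt).sum()   # missed X-point pixels
    tn = np.logical_and(~pred, ~gt).sum()  # correctly rejected background

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn + eps),
        "precision": precision,
        "recall": recall,
        "f1_score": 2 * precision * recall / (precision + recall + eps),
        "iou": tp / (tp + fp + fn + eps),  # intersection over union of the two masks
    }
```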
+### Understanding the Results
+
+**Good performance indicators:**
+- F1 Score > 0.8
+- IoU > 0.5
+- Similar metrics between training and validation sets (no overfitting)
+- Low standard deviation across frames (consistent performance)
+
+**Warning signs:**
+- Large gap between training and validation metrics (overfitting)
+- High precision but low recall (too conservative, missing X-points)
+- Low precision but high recall (too aggressive, many false alarms)
+- High frame-to-frame variation (inconsistent detection)
+
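As an illustration, these indicators could be checked programmatically after a run. The thresholds are the ones listed above; the dictionary keys and the 0.1 precision/recall gap are assumptions made for this sketch only.

```python
def check_metrics(train, val):
    """Flag the warning signs listed above (hypothetical; expects metric dicts)."""
    if val["f1_score"] > 0.8 and val["iou"] > 0.5:
        print("Good: validation F1 and IoU above target thresholds")
    if train["f1_score"] - val["f1_score"] > 0.05:
        print("Warning: possible overfitting (train F1 >> val F1)")
    # The 0.1 gap below is an arbitrary illustration, not a project-defined threshold.
    if val["precision"] > val["recall"] + 0.1:
        print("Warning: too conservative, likely missing X-points")
    elif val["recall"] > val["precision"] + 0.1:
        print("Warning: too aggressive, likely many false alarms")
```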
+### Output Files
+
+After training, the model produces:
+- `evaluation_metrics.json`: Validation set performance
+- `train_evaluation_metrics.json`: Training set performance
+- Performance plots in the `plots/` directory showing:
+  - Training history (loss curves)
+  - Model predictions vs ground truth
+  - True positives (green), false positives (red), false negatives (yellow)
+
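The exact JSON schema is defined by `save_json` in the new `eval_metrics` module, which this diff does not include. Assuming the files expose the same keys as `get_global_metrics()` in `XPointMLTest.py` (e.g. `f1_score`), a downstream script might consume them like this:

```python
import json

# Hypothetical consumer: compare validation and training performance after a run.
with open("evaluation_metrics.json") as f:
    val_metrics = json.load(f)
with open("train_evaluation_metrics.json") as f:
    train_metrics = json.load(f)

# Key names assume the schema mirrors get_global_metrics(); adjust to the real layout.
print(f"val F1 = {val_metrics['f1_score']:.4f}, "
      f"train F1 = {train_metrics['f1_score']:.4f}")
```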
+### Physics Context
+
+For reconnection studies:
+- **High recall is critical**: Missing X-points means missing reconnection events
+- **Precision affects analysis**: False positives corrupt downstream calculations
+- **IoU indicates localization**: Poor IoU means inaccurate X-point positions
+
+The model uses a 9×9 pixel expansion around X-points to account for localization uncertainty while still requiring accurate region identification.
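The expansion is performed by `expand_xpoints_mask` in `XPointMLTest.py`; only its signature appears in the diff below, so the following dilation-based sketch is an assumption about a typical implementation, not the repository's code.

```python
from scipy.ndimage import maximum_filter

def expand_xpoints_mask_sketch(binary_mask, kernel_size=9):
    """Grow each X-point pixel into a kernel_size x kernel_size neighborhood (hypothetical)."""
    # A maximum filter over a square window sets every pixel within
    # kernel_size // 2 cells of an X-point to 1.
    return maximum_filter(binary_mask, size=kernel_size).astype(binary_mask.dtype)
```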

XPointMLTest.py

Lines changed: 112 additions & 14 deletions
@@ -25,6 +25,12 @@
 
 from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality
 
+# Import benchmark module
+from benchmark import TrainingBenchmark
+
+# Import evaluation metrics module
+from eval_metrics import ModelEvaluator, evaluate_model_on_dataset
+
 def expand_xpoints_mask(binary_mask, kernel_size=9):
     """
     Expands each X-point in a binary mask to include surrounding cells
@@ -481,11 +487,17 @@ def forward(self, inputs, targets):
         return 1.0 - dice
 
 # TRAIN & VALIDATION UTILS
-def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype):
+def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark=None):
     model.train()
     running_loss = 0.0
 
+    # Start epoch timing for benchmark
+    if benchmark:
+        benchmark.start_epoch()
+
     for batch in loader:
+        batch_start = timer()
+
         all_data, mask = batch["all"].to(device), batch["mask"].to(device)
 
         with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
@@ -514,6 +526,15 @@ def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype):
         optimizer.step()
 
         running_loss += loss.item()
+
+        # Record batch timing for benchmark
+        if benchmark:
+            batch_time = timer() - batch_start
+            benchmark.record_batch(all_data.size(0), batch_time)
+
+    # End epoch timing for benchmark
+    if benchmark:
+        benchmark.end_epoch()
 
     return running_loss / len(loader) if len(loader) > 0 else 0.0
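`TrainingBenchmark` is defined in the new `benchmark` module, which is not part of this excerpt. A minimal sketch consistent with the interface exercised here and later in `main()` (`start_epoch`/`record_batch`/`end_epoch`, `get_throughput`, `get_gpu_memory_usage`, `print_hardware_info`, `print_summary`) might look like the following; every internal detail is an assumption. Defining `__bool__` would make the `if benchmark:` guards above respect the `--benchmark` flag even though an instance is always passed in.

```python
import json
import torch
from timeit import default_timer as timer

class TrainingBenchmarkSketch:
    """Hypothetical stand-in for benchmark.TrainingBenchmark (real module not shown)."""

    def __init__(self, device, enabled=False):
        self.device = device
        self.enabled = enabled
        self.epoch_times = []   # wall-clock seconds per epoch
        self.samples = 0        # samples seen across recorded batches
        self.batch_time = 0.0   # cumulative per-batch wall-clock time

    def __bool__(self):
        # Lets `if benchmark:` respect the --benchmark flag (assumed behavior).
        return self.enabled

    def print_hardware_info(self):
        if self.device.type == "cuda":
            print(f"GPU: {torch.cuda.get_device_name(self.device)}")

    def start_epoch(self):
        self._epoch_start = timer()

    def record_batch(self, batch_size, batch_time):
        self.samples += batch_size
        self.batch_time += batch_time

    def end_epoch(self):
        self.epoch_times.append(timer() - self._epoch_start)

    def get_throughput(self):
        # Samples per second of measured batch time.
        return self.samples / self.batch_time if self.batch_time > 0 else 0.0

    def get_gpu_memory_usage(self):
        # Peak allocated GPU memory in GB (0.0 on CPU-only runs).
        if self.device.type == "cuda":
            return torch.cuda.max_memory_allocated(self.device) / 1e9
        return 0.0

    def print_summary(self, output_file=None):
        summary = {
            "epoch_times_s": self.epoch_times,
            "throughput_samples_per_s": self.get_throughput(),
            "peak_gpu_memory_gb": self.get_gpu_memory_usage(),
        }
        print(summary)
        if output_file is not None:
            with open(output_file, "w") as f:
                json.dump(summary, f, indent=2)
```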

@@ -672,20 +693,20 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi
 
     plt.close()
 
-def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'):
+def plot_training_history(train_losses, val_loss, save_path='plots/training_history.png'):
     """
     Plots training and validation losses across epochs.
 
     Parameters:
         train_losses (list): List of training losses for each epoch
-        val_losses (list): List of validation losses for each epoch
+        val_loss (list): List of validation losses for each epoch
         save_path (str): Path to save the resulting plot
     """
     plt.figure(figsize=(10, 6))
     epochs = range(1, len(train_losses) + 1)
 
     plt.plot(epochs, train_losses, 'b-', label='Training Loss')
-    plt.plot(epochs, val_losses, 'r-', label='Validation Loss')
+    plt.plot(epochs, val_loss, 'r-', label='Validation Loss')
 
     plt.title('Training and Validation Loss')
     plt.xlabel('Epochs')
@@ -695,8 +716,8 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'):
     plt.grid(True, linestyle='--', alpha=0.7)
 
     # Add some padding to y-axis to make visualization clearer
-    ymin = min(min(train_losses), min(val_losses)) * 0.9
-    ymax = max(max(train_losses), max(val_losses)) * 1.1
+    ymin = min(min(train_losses), min(val_loss)) * 0.9
+    ymax = max(max(train_losses), max(val_loss)) * 1.1
     plt.ylim(ymin, ymax)
 
     plt.savefig(save_path, dpi=300)
@@ -746,6 +767,12 @@ def parseCommandLineArgs():
                         choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)')
     parser.add_argument('--patience', type=int, default=15,
                         help='patience for early stopping (default: 15)')
+    parser.add_argument('--benchmark', action='store_true',
+                        help='enable performance benchmarking (tracks timing, throughput, GPU memory)')
+    parser.add_argument('--benchmark-output', type=Path, default='./benchmark_results.json',
+                        help='path to save benchmark results JSON file (default: ./benchmark_results.json)')
+    parser.add_argument('--eval-output', type=Path, default='./evaluation_metrics.json',
+                        help='path to save evaluation metrics JSON file (default: ./evaluation_metrics.json)')
 
     # CI TEST: Add smoke test flag
     parser.add_argument('--smoke-test', action='store_true',
@@ -974,6 +1001,11 @@ def main():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
 
+    # Initialize benchmark tracker
+    benchmark = TrainingBenchmark(device, enabled=args.benchmark)
+    if args.benchmark:
+        benchmark.print_hardware_info()
+
     # Use the improved model
     model = UNet(input_channels=4, base_channels=32).to(device)

@@ -1028,14 +1060,21 @@ def main():
 
     num_epochs = args.epochs
     for epoch in range(start_epoch, num_epochs):
-        train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype)
+        train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark)
         val_loss_epoch = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype)
 
         train_loss.append(train_loss_epoch)
         val_loss.append(val_loss_epoch)
 
         current_lr = optimizer.param_groups[0]['lr']
-        print(f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}")
+
+        # Enhanced logging with benchmark metrics
+        log_msg = f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}"
+        if args.benchmark:
+            throughput = benchmark.get_throughput()
+            gpu_mem = benchmark.get_gpu_memory_usage()
+            log_msg += f" | Throughput={throughput:.2f} samples/s | GPU Mem={gpu_mem:.2f} GB"
+        print(log_msg)
 
         # Learning rate scheduling
         scheduler.step()
@@ -1059,8 +1098,12 @@ def main():
             print(f"Early stopping triggered after {epoch+1} epochs (patience={args.patience})")
             break
 
-    plot_training_history(train_loss, val_loss)
+    plot_training_history(train_loss, val_loss, save_path='plots/training_history.png')
     print("time (s) to train model: " + str(timer()-t2))
+
+    # Print and save benchmark summary
+    if args.benchmark:
+        benchmark.print_summary(output_file=args.benchmark_output)
 
     # CI TEST: Run additional tests if in smoke test mode
     if args.smoke_test:
@@ -1134,12 +1177,71 @@ def main():
     print("Loading best model for evaluation...")
     model.load_state_dict(torch.load(best_model_path, weights_only=True))
 
+    # New evaluation code
+    # Evaluate model performance
+    if not args.smoke_test:
+        # print("\n" + "="*70)
+        # print("RUNNING MODEL EVALUATION")
+        # print("="*70)
+
+        # # Evaluate on validation set
+        # print("\nEvaluating on validation set...")
+        val_evaluator = evaluate_model_on_dataset(
+            model,
+            val_dataset,  # Use the original dataset, not the patch dataset
+            device,
+            use_amp=use_amp,
+            amp_dtype=amp_dtype,
+            threshold=0.5
+        )
+
+        # Print and save validation metrics
+        val_evaluator.print_summary()
+        val_evaluator.save_json(args.eval_output)
+
+        # Evaluate on training set
+        print("\nEvaluating on training set...")
+        train_evaluator = evaluate_model_on_dataset(
+            model,
+            train_dataset,
+            device,
+            use_amp=use_amp,
+            amp_dtype=amp_dtype,
+            threshold=0.5
+        )
+
+        # Print and save training metrics
+        train_evaluator.print_summary()
+        train_eval_path = args.eval_output.parent / f"train_{args.eval_output.name}"
+        train_evaluator.save_json(train_eval_path)
+
+        # Compare training vs validation to check for overfitting
+        train_global = train_evaluator.get_global_metrics()
+        val_global = val_evaluator.get_global_metrics()
+
+        print("\n" + "="*70)
+        print("OVERFITTING CHECK")
+        print("="*70)
+        print(f"Training F1: {train_global['f1_score']:.4f}")
+        print(f"Validation F1: {val_global['f1_score']:.4f}")
+        print(f"Difference: {abs(train_global['f1_score'] - val_global['f1_score']):.4f}")
+
+        if train_global['f1_score'] - val_global['f1_score'] > 0.05:
+            print("⚠ Warning: Possible overfitting detected (train F1 >> val F1)")
+        elif val_global['f1_score'] - train_global['f1_score'] > 0.05:
+            print("⚠ Warning: Unusual pattern (val F1 >> train F1)")
+        else:
+            print("✓ Model generalizes well to validation set")
+        print("="*70 + "\n")
+
+    # ==================== END NEW EVALUATION CODE ====================
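`evaluate_model_on_dataset` and `ModelEvaluator` come from the new `eval_metrics` module, which this excerpt does not show. Given the call signature above, its core inference loop presumably resembles this sketch; the dataset keys and the final sigmoid are assumptions based on how `XPointMLTest.py` handles batches elsewhere in this diff.

```python
import torch
from torch.amp import autocast

@torch.no_grad()
def evaluate_core_sketch(model, dataset, device, use_amp, amp_dtype, threshold=0.5):
    """Hypothetical core of evaluate_model_on_dataset: per-frame inference + thresholding."""
    model.eval()
    pairs = []
    for sample in dataset:
        # Dataset items expose "all" (inputs) and "mask" (ground truth), as in train_one_epoch.
        inputs = sample["all"].unsqueeze(0).to(device)  # add a batch dimension
        with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
            out = model(inputs)
        # Assumes the model returns logits; drop the sigmoid if it already outputs probabilities.
        probs = torch.sigmoid(out.float())
        pred_mask = (probs > threshold).float().cpu().numpy()
        pairs.append((pred_mask, sample["mask"].cpu().numpy()))
    return pairs  # (predicted, ground-truth) mask pairs for metric accumulation
```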
+
     # (D) Plotting after training
     model.eval()  # switch to inference mode
     outDir = "plots"
     interpFac = 1
 
-    # Evaluate on combined set for demonstration. Exam this part to see if save to remove
+    # Evaluate on combined set for demonstration
     if not args.smoke_test:
         train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
         val_fnums = range(args.validationFrameFirst, args.validationFrameLast)

@@ -1175,10 +1277,6 @@ def main():
 
             pred_mask_bin = (pred_prob_np[0,0] > 0.5).astype(np.float32)
 
-            print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:")
-            print(f"  Probabilities - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}")
-            print(f"  Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels")
-
             if args.plot:
                 # Plot GROUND TRUTH
                 plot_psi_contours_and_xpoints(
