diff --git a/README.md b/README.md index 55907d3..4025f26 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,11 @@ The classifier supports several command line options for training configuration: - `--plotDir`: Directory where figures are written (default: `./plots`) - `--checkPointFrequency`: Number of epochs between model checkpoints (default: 10) +### Performance Benchmarking +- `--benchmark`: Enable performance benchmarking (tracks timing, throughput, GPU memory) +- `--benchmark-output`: Path to save benchmark results JSON file (default: `./benchmark_results.json`) +- `--eval-output`: Path to save evaluation metrics JSON file (default: `./evaluation_metrics.json`) + ### Testing - `--smoke-test`: Run minimal smoke test for CI (overrides other parameters for quick validation) @@ -138,4 +143,51 @@ The following commands should be run on `checkers` **every time you create a new cd nsfCssiMlClassifier source envPyTorch.sh source pgkyl/bin/activate -``` \ No newline at end of file +``` + +## Model Evaluation Metrics + +The model evaluation system measures how well the classifier identifies X-points (magnetic reconnection sites) by treating it as a pixel-level binary classification problem. + +### Key Metrics + +The evaluation outputs several metrics saved to JSON files: + +- **Accuracy**: Overall pixel classification correctness (can be misleading due to class imbalance) +- **Precision**: Fraction of detected X-points that are correct (measures false alarm rate) +- **Recall**: Fraction of actual X-points that were found (measures miss rate) +- **F1 Score**: Harmonic mean of precision and recall (balanced performance metric) +- **IoU**: Intersection over Union - spatial overlap quality between predicted and actual X-point regions + +### Understanding the Results + +**Good performance indicators:** +- F1 Score > 0.8 +- IoU > 0.5 +- Similar metrics between training and validation sets (no overfitting) +- Low standard deviation across frames (consistent performance) + +**Warning signs:** +- Large gap between training and validation metrics (overfitting) +- High precision but low recall (too conservative, missing X-points) +- Low precision but high recall (too aggressive, many false alarms) +- High frame-to-frame variation (inconsistent detection) + +### Output Files + +After training, the model produces: +- `evaluation_metrics.json`: Validation set performance +- `train_evaluation_metrics.json`: Training set performance +- Performance plots in the `plots/` directory showing: + - Training history (loss curves) + - Model predictions vs ground truth + - True positives (green), false positives (red), false negatives (yellow) + +### Physics Context + +For reconnection studies: +- **High recall is critical**: Missing X-points means missing reconnection events +- **Precision affects analysis**: False positives corrupt downstream calculations +- **IoU indicates localization**: Poor IoU means inaccurate X-point positions + +The model uses a 9×9 pixel expansion around X-points to account for localization uncertainty while still requiring accurate region identification. 
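The saved JSON files can also be inspected programmatically. Below is a minimal sketch, not part of the shipped tooling: it assumes the default `--eval-output` path and the key layout written by `eval_metrics.py` (a top-level `global_metrics` entry containing `f1_score` and `iou`), and applies the targets listed under "Understanding the Results" above:

```python
import json

# Load the validation metrics written after training (default --eval-output path)
with open("evaluation_metrics.json") as f:
    metrics = json.load(f)["global_metrics"]

print(f"F1 score: {metrics['f1_score']:.4f}")
print(f"IoU:      {metrics['iou']:.4f}")

# Suggested targets from "Understanding the Results"
if metrics["f1_score"] > 0.8 and metrics["iou"] > 0.5:
    print("Meets the suggested performance targets.")
else:
    print("Below target: consider more training data, longer training, or threshold tuning.")
```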
\ No newline at end of file diff --git a/XPointMLTest.py b/XPointMLTest.py index 8404d5c..22f1a69 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -25,6 +25,12 @@ from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality +# Import benchmark module +from benchmark import TrainingBenchmark + +# Import evaluation metrics module +from eval_metrics import ModelEvaluator, evaluate_model_on_dataset + def expand_xpoints_mask(binary_mask, kernel_size=9): """ Expands each X-point in a binary mask to include surrounding cells @@ -481,11 +487,17 @@ def forward(self, inputs, targets): return 1.0 - dice # TRAIN & VALIDATION UTILS -def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype): +def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark=None): model.train() running_loss = 0.0 + # Start epoch timing for benchmark + if benchmark: + benchmark.start_epoch() + for batch in loader: + batch_start = timer() + all_data, mask = batch["all"].to(device), batch["mask"].to(device) with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): @@ -514,6 +526,15 @@ def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp optimizer.step() running_loss += loss.item() + + # Record batch timing for benchmark + if benchmark: + batch_time = timer() - batch_start + benchmark.record_batch(all_data.size(0), batch_time) + + # End epoch timing for benchmark + if benchmark: + benchmark.end_epoch() return running_loss / len(loader) if len(loader) > 0 else 0.0 @@ -672,20 +693,20 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi plt.close() -def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'): +def plot_training_history(train_losses, val_loss, save_path='plots/training_history.png'): """ Plots training and validation losses across epochs. 
Parameters: train_losses (list): List of training losses for each epoch - val_losses (list): List of validation losses for each epoch + val_loss (list): List of validation losses for each epoch save_path (str): Path to save the resulting plot """ plt.figure(figsize=(10, 6)) epochs = range(1, len(train_losses) + 1) plt.plot(epochs, train_losses, 'b-', label='Training Loss') - plt.plot(epochs, val_losses, 'r-', label='Validation Loss') + plt.plot(epochs, val_loss, 'r-', label='Validation Loss') plt.title('Training and Validation Loss') plt.xlabel('Epochs') @@ -695,8 +716,8 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi plt.grid(True, linestyle='--', alpha=0.7) # Add some padding to y-axis to make visualization clearer - ymin = min(min(train_losses), min(val_losses)) * 0.9 - ymax = max(max(train_losses), max(val_losses)) * 1.1 + ymin = min(min(train_losses), min(val_loss)) * 0.9 + ymax = max(max(train_losses), max(val_loss)) * 1.1 plt.ylim(ymin, ymax) plt.savefig(save_path, dpi=300) @@ -746,6 +767,12 @@ def parseCommandLineArgs(): choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)') parser.add_argument('--patience', type=int, default=15, help='patience for early stopping (default: 15)') + parser.add_argument('--benchmark', action='store_true', + help='enable performance benchmarking (tracks timing, throughput, GPU memory)') + parser.add_argument('--benchmark-output', type=Path, default='./benchmark_results.json', + help='path to save benchmark results JSON file (default: ./benchmark_results.json)') + parser.add_argument('--eval-output', type=Path, default='./evaluation_metrics.json', + help='path to save evaluation metrics JSON file (default: ./evaluation_metrics.json)') # CI TEST: Add smoke test flag parser.add_argument('--smoke-test', action='store_true', @@ -974,6 +1001,11 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") + # Initialize benchmark tracker + benchmark = TrainingBenchmark(device, enabled=args.benchmark) + if args.benchmark: + benchmark.print_hardware_info() + # Use the improved model model = UNet(input_channels=4, base_channels=32).to(device) @@ -1028,14 +1060,21 @@ def main(): num_epochs = args.epochs for epoch in range(start_epoch, num_epochs): - train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype) + train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark) val_loss_epoch = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype) train_loss.append(train_loss_epoch) val_loss.append(val_loss_epoch) current_lr = optimizer.param_groups[0]['lr'] - print(f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}") + + # Enhanced logging with benchmark metrics + log_msg = f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}" + if args.benchmark: + throughput = benchmark.get_throughput() + gpu_mem = benchmark.get_gpu_memory_usage() + log_msg += f" | Throughput={throughput:.2f} samples/s | GPU Mem={gpu_mem:.2f} GB" + print(log_msg) # Learning rate scheduling scheduler.step() @@ -1059,8 +1098,12 @@ def main(): print(f"Early stopping triggered after {epoch+1} epochs (patience={args.patience})") break - plot_training_history(train_loss, val_loss) + plot_training_history(train_loss, 
val_loss, save_path='plots/training_history.png') print("time (s) to train model: " + str(timer()-t2)) + + # Print and save benchmark summary + if args.benchmark: + benchmark.print_summary(output_file=args.benchmark_output) # CI TEST: Run additional tests if in smoke test mode if args.smoke_test: @@ -1134,12 +1177,71 @@ def main(): print("Loading best model for evaluation...") model.load_state_dict(torch.load(best_model_path, weights_only=True)) + # new evaluation code + # Evaluate model performance + if not args.smoke_test: + # print("\n" + "="*70) + # print("RUNNING MODEL EVALUATION") + # print("="*70) + + # # Evaluate on validation set + # print("\nEvaluating on validation set...") + val_evaluator = evaluate_model_on_dataset( + model, + val_dataset, # Use original dataset, not patch dataset + device, + use_amp=use_amp, + amp_dtype=amp_dtype, + threshold=0.5 + ) + + # Print and save validation metrics + val_evaluator.print_summary() + val_evaluator.save_json(args.eval_output) + + # Evaluate on training set + print("\nEvaluating on training set...") + train_evaluator = evaluate_model_on_dataset( + model, + train_dataset, + device, + use_amp=use_amp, + amp_dtype=amp_dtype, + threshold=0.5 + ) + + # Print and save training metrics + train_evaluator.print_summary() + train_eval_path = args.eval_output.parent / f"train_{args.eval_output.name}" + train_evaluator.save_json(train_eval_path) + + # Compare training vs validation to check for overfitting + train_global = train_evaluator.get_global_metrics() + val_global = val_evaluator.get_global_metrics() + + print("\n" + "="*70) + print("OVERFITTING CHECK") + print("="*70) + print(f"Training F1: {train_global['f1_score']:.4f}") + print(f"Validation F1: {val_global['f1_score']:.4f}") + print(f"Difference: {abs(train_global['f1_score'] - val_global['f1_score']):.4f}") + + if train_global['f1_score'] - val_global['f1_score'] > 0.05: + print("⚠ Warning: Possible overfitting detected (train F1 >> val F1)") + elif val_global['f1_score'] - train_global['f1_score'] > 0.05: + print("⚠ Warning: Unusual pattern (val F1 >> train F1)") + else: + print("✓ Model generalizes well to validation set") + print("="*70 + "\n") + + # ==================== END NEW EVALUATION CODE ==================== + # (D) Plotting after training model.eval() # switch to inference mode outDir = "plots" interpFac = 1 - # Evaluate on combined set for demonstration. Exam this part to see if save to remove + # Evaluate on combined set for demonstration if not args.smoke_test: train_fnums = range(args.trainFrameFirst, args.trainFrameLast) val_fnums = range(args.validationFrameFirst, args.validationFrameLast) @@ -1175,10 +1277,6 @@ def main(): pred_mask_bin = (pred_prob_np[0,0] > 0.5).astype(np.float32) - print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") - print(f" Probabilities - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}") - print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") - if args.plot: # Plot GROUND TRUTH plot_psi_contours_and_xpoints( diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 0000000..1e965f5 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,288 @@ +import time +import json +import torch +import platform + +import numpy as np +from pathlib import Path + +# Make psutil optional +try: + import psutil + PSUTIL_AVAILABLE = True +except ImportError: + PSUTIL_AVAILABLE = False + print("Warning: psutil not available. 
Hardware info will be limited.") + print("Install with: pip install psutil") + + +class TrainingBenchmark: + """ + Tracks and reports performance metrics during training. + + Measures: + - Hardware specifications + - Epoch timing + - Batch processing throughput + - GPU memory usage + - Samples processed per second + """ + + def __init__(self, device, enabled=True): + """ + Initialize benchmark tracker. + + Parameters: + device: torch.device - The device being used for training + enabled: bool - Whether benchmarking is active + """ + self.device = device + self.enabled = enabled + + if not enabled: + return + + self.epoch_times = [] + self.batch_times = [] + self.samples_processed = 0 + self.epoch_start = None + self.training_start = time.time() + + #collect hardware info + self.hardware_info = self._collect_hardware_info() + + def _collect_hardware_info(self): + """Collect system and GPU hardware information.""" + info = { + 'platform': platform.system(), + 'processor': platform.processor(), + 'python_version': platform.python_version(), + 'torch_version': torch.__version__, + } + + # Add psutil info if available + if PSUTIL_AVAILABLE: + info['cpu_count'] = psutil.cpu_count(logical=False) + info['cpu_count_logical'] = psutil.cpu_count(logical=True) + info['ram_gb'] = round(psutil.virtual_memory().total / (1024**3), 2) + else: + info['cpu_count'] = 'N/A (psutil not installed)' + info['cpu_count_logical'] = 'N/A (psutil not installed)' + info['ram_gb'] = 'N/A (psutil not installed)' + + if torch.cuda.is_available(): + info['gpu_name'] = torch.cuda.get_device_name(0) + info['gpu_count'] = torch.cuda.device_count() + info['cuda_version'] = torch.version.cuda + info['cudnn_version'] = torch.backends.cudnn.version() + info['gpu_memory_gb'] = round( + torch.cuda.get_device_properties(0).total_memory / (1024**3), 2 + ) + + #multi-GPU info + if torch.cuda.device_count() > 1: + info['all_gpus'] = [ + { + 'id': i, + 'name': torch.cuda.get_device_name(i), + 'memory_gb': round( + torch.cuda.get_device_properties(i).total_memory / (1024**3), 2 + ) + } + for i in range(torch.cuda.device_count()) + ] + else: + info['gpu_name'] = 'CPU only' + info['gpu_count'] = 0 + + return info + + def start_epoch(self): + """Mark the start of an epoch.""" + if not self.enabled: + return + self.epoch_start = time.time() + + def end_epoch(self): + """Mark the end of an epoch and record timing.""" + if not self.enabled: + return + if self.epoch_start is not None: + self.epoch_times.append(time.time() - self.epoch_start) + + def record_batch(self, batch_size, batch_time): + """ + Record metrics for a single batch. + + Parameters: + batch_size: int - Number of samples in the batch + batch_time: float - Time taken to process the batch (seconds) + """ + if not self.enabled: + return + self.batch_times.append(batch_time) + self.samples_processed += batch_size + + def get_throughput(self): + """ + Calculate average throughput. + + Returns: + float - Samples processed per second + """ + if not self.enabled or len(self.batch_times) == 0: + return 0 + total_time = sum(self.batch_times) + return self.samples_processed / total_time if total_time > 0 else 0 + + def get_gpu_memory_usage(self): + """ + Get current GPU memory usage. + + Returns: + float - Current GPU memory allocated in GB + """ + if not self.enabled: + return 0 + if torch.cuda.is_available(): + return torch.cuda.memory_allocated() / (1024**3) + return 0 + + def get_peak_gpu_memory(self): + """ + Get peak GPU memory usage. 
+ + Returns: + float - Peak GPU memory allocated in GB + """ + if not self.enabled: + return 0 + if torch.cuda.is_available(): + return torch.cuda.max_memory_allocated() / (1024**3) + return 0 + + def print_hardware_info(self): + """Print hardware configuration at start of training.""" + if not self.enabled: + return + + print("\n" + "="*70) + print("HARDWARE CONFIGURATION") + print("="*70) + print(f"Platform: {self.hardware_info['platform']}") + print(f"CPU: {self.hardware_info['processor']}") + + if PSUTIL_AVAILABLE: + print(f"CPU Cores: {self.hardware_info['cpu_count']} physical, " + f"{self.hardware_info['cpu_count_logical']} logical") + print(f"RAM: {self.hardware_info['ram_gb']:.2f} GB") + else: + print(f"CPU Cores: {self.hardware_info['cpu_count']}") + print(f"RAM: {self.hardware_info['ram_gb']}") + + print(f"Python Version: {self.hardware_info['python_version']}") + print(f"PyTorch Version: {self.hardware_info['torch_version']}") + + if self.hardware_info['gpu_count'] > 0: + print(f"\nGPU Information:") + print(f" Primary GPU: {self.hardware_info['gpu_name']}") + print(f" GPU Count: {self.hardware_info['gpu_count']}") + print(f" VRAM: {self.hardware_info['gpu_memory_gb']:.2f} GB") + print(f" CUDA Version: {self.hardware_info['cuda_version']}") + print(f" cuDNN Version: {self.hardware_info['cudnn_version']}") + + if 'all_gpus' in self.hardware_info: + print(f"\n All GPUs:") + for gpu in self.hardware_info['all_gpus']: + print(f" [{gpu['id']}] {gpu['name']} ({gpu['memory_gb']:.2f} GB)") + else: + print("\nGPU: Not available (CPU only)") + + print("="*70 + "\n") + + def print_summary(self, output_file=None): + """ + Print comprehensive benchmark summary. + + Parameters: + output_file: Path or str - Optional file path to save JSON benchmark data + """ + if not self.enabled: + return + + total_training_time = time.time() - self.training_start + + print("\n" + "="*70) + print("BENCHMARK SUMMARY") + print("="*70) + + if len(self.epoch_times) > 0: + print(f"\nTraining Performance:") + print(f" Total epochs completed: {len(self.epoch_times)}") + print(f" Total training time: {total_training_time:.2f}s " + f"({total_training_time/60:.2f} min)") + print(f" Average epoch time: {np.mean(self.epoch_times):.2f}s") + print(f" Fastest epoch: {np.min(self.epoch_times):.2f}s") + print(f" Slowest epoch: {np.max(self.epoch_times):.2f}s") + print(f" Epoch time std dev: {np.std(self.epoch_times):.2f}s") + + if len(self.batch_times) > 0: + print(f"\nThroughput Metrics:") + print(f" Total samples processed: {self.samples_processed:,}") + print(f" Total batches processed: {len(self.batch_times):,}") + print(f" Average batch time: {np.mean(self.batch_times)*1000:.2f}ms") + print(f" Throughput: {self.get_throughput():.2f} samples/sec") + print(f" Time per sample: {1000/self.get_throughput():.2f}ms") + + if torch.cuda.is_available(): + print(f"\nGPU Memory Usage:") + print(f" Current allocation: {self.get_gpu_memory_usage():.2f} GB") + print(f" Peak allocation: {self.get_peak_gpu_memory():.2f} GB") + print(f" Total GPU memory: {self.hardware_info['gpu_memory_gb']:.2f} GB") + print(f" Peak utilization: " + f"{100*self.get_peak_gpu_memory()/self.hardware_info['gpu_memory_gb']:.1f}%") + + print("="*70 + "\n") + + if output_file is not None: + self.save_json(output_file) + + def save_json(self, output_file): + """ + Save benchmark data to JSON file. 
+ + Parameters: + output_file: Path or str - File path to save benchmark data + """ + if not self.enabled: + return + + benchmark_data = { + 'hardware': self.hardware_info, + 'training': { + 'total_epochs': len(self.epoch_times), + 'total_training_time_sec': time.time() - self.training_start, + 'epoch_times_sec': self.epoch_times, + 'avg_epoch_time_sec': float(np.mean(self.epoch_times)) if self.epoch_times else 0, + 'min_epoch_time_sec': float(np.min(self.epoch_times)) if self.epoch_times else 0, + 'max_epoch_time_sec': float(np.max(self.epoch_times)) if self.epoch_times else 0, + }, + 'throughput': { + 'total_samples': self.samples_processed, + 'total_batches': len(self.batch_times), + 'avg_batch_time_sec': float(np.mean(self.batch_times)) if self.batch_times else 0, + 'samples_per_sec': self.get_throughput(), + }, + 'gpu_memory': { + 'peak_allocation_gb': self.get_peak_gpu_memory(), + 'final_allocation_gb': self.get_gpu_memory_usage(), + } if torch.cuda.is_available() else None + } + + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + json.dump(benchmark_data, f, indent=2) + + print(f"Benchmark data saved to: {output_path}") \ No newline at end of file diff --git a/eval_metrics.py b/eval_metrics.py new file mode 100644 index 0000000..7ea7e60 --- /dev/null +++ b/eval_metrics.py @@ -0,0 +1,266 @@ +""" +Model evaluation metrics for X-point detection. + +This module provides functions to compute detailed performance metrics +for the X-point detection model, including per-frame and global statistics. +""" + +import numpy as np +import json +from pathlib import Path +import torch +from torch.amp import autocast + + +class ModelEvaluator: + """ + Evaluates model performance on X-point detection task. + + Computes metrics including: + - True Positives (TP): X-point pixels correctly identified + - False Positives (FP): Background pixels incorrectly labeled as X-points + - False Negatives (FN): X-point pixels that were missed + - True Negatives (TN): Background pixels correctly identified + + Metrics calculated: + - Accuracy: (TP + TN) / (TP + TN + FP + FN) + - Precision: TP / (TP + FP) + - Recall: TP / (TP + FN) + - F1 Score: 2 * (Precision * Recall) / (Precision + Recall) + - IoU: TP / (TP + FP + FN) + """ + + def __init__(self, threshold=0.5): + """ + Initialize evaluator. + + Parameters: + threshold: float - Probability threshold for binary classification (default: 0.5) + """ + self.threshold = threshold + self.reset() + + def reset(self): + """Reset all accumulated metrics.""" + self.global_tp = 0 + self.global_fp = 0 + self.global_fn = 0 + self.global_tn = 0 + self.frame_metrics = [] + + def compute_frame_metrics(self, pred_probs, ground_truth): + """ + Compute metrics for a single frame. 
+ + Parameters: + pred_probs: np.ndarray - Predicted probabilities, shape [H, W] + ground_truth: np.ndarray - Ground truth binary mask, shape [H, W] + + Returns: + dict - Dictionary containing TP, FP, FN, TN and derived metrics + """ + # Binarize predictions + pred_binary = (pred_probs > self.threshold).astype(np.float32) + gt_binary = (ground_truth > 0.5).astype(np.float32) + + # Compute confusion matrix elements + tp = np.sum((pred_binary == 1) & (gt_binary == 1)) + fp = np.sum((pred_binary == 1) & (gt_binary == 0)) + fn = np.sum((pred_binary == 0) & (gt_binary == 1)) + tn = np.sum((pred_binary == 0) & (gt_binary == 0)) + + # Compute derived metrics + total = tp + fp + fn + tn + accuracy = (tp + tn) / total if total > 0 else 0.0 + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 + iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0 + + return { + 'tp': int(tp), + 'fp': int(fp), + 'fn': int(fn), + 'tn': int(tn), + 'accuracy': float(accuracy), + 'precision': float(precision), + 'recall': float(recall), + 'f1_score': float(f1), + 'iou': float(iou) + } + + def add_frame(self, pred_probs, ground_truth, frame_id=None): + """ + Add a frame's results to the evaluation. + + Parameters: + pred_probs: np.ndarray - Predicted probabilities + ground_truth: np.ndarray - Ground truth binary mask + frame_id: int or str - Optional frame identifier + """ + metrics = self.compute_frame_metrics(pred_probs, ground_truth) + + # Add to global counts + self.global_tp += metrics['tp'] + self.global_fp += metrics['fp'] + self.global_fn += metrics['fn'] + self.global_tn += metrics['tn'] + + # Store frame metrics + if frame_id is not None: + metrics['frame_id'] = frame_id + self.frame_metrics.append(metrics) + + def get_global_metrics(self): + """ + Compute global metrics across all frames. + + Returns: + dict - Global metrics computed from accumulated confusion matrix + """ + total = self.global_tp + self.global_fp + self.global_fn + self.global_tn + + metrics = { + 'global_tp': int(self.global_tp), + 'global_fp': int(self.global_fp), + 'global_fn': int(self.global_fn), + 'global_tn': int(self.global_tn), + 'total_pixels': int(total), + 'accuracy': (self.global_tp + self.global_tn) / total if total > 0 else 0.0, + 'precision': self.global_tp / (self.global_tp + self.global_fp) + if (self.global_tp + self.global_fp) > 0 else 0.0, + 'recall': self.global_tp / (self.global_tp + self.global_fn) + if (self.global_tp + self.global_fn) > 0 else 0.0, + 'iou': self.global_tp / (self.global_tp + self.global_fp + self.global_fn) + if (self.global_tp + self.global_fp + self.global_fn) > 0 else 0.0, + } + + # Compute F1 from global precision and recall + if (metrics['precision'] + metrics['recall']) > 0: + metrics['f1_score'] = 2 * metrics['precision'] * metrics['recall'] / \ + (metrics['precision'] + metrics['recall']) + else: + metrics['f1_score'] = 0.0 + + return metrics + + def get_frame_statistics(self): + """ + Compute statistics across all frames. 
+ + Returns: + dict - Mean and standard deviation for each metric across frames + """ + if not self.frame_metrics: + return {} + + metrics_arrays = { + key: np.array([frame[key] for frame in self.frame_metrics]) + for key in ['accuracy', 'precision', 'recall', 'f1_score', 'iou'] + } + + stats = {} + for metric_name, values in metrics_arrays.items(): + stats[f'{metric_name}_mean'] = float(np.mean(values)) + stats[f'{metric_name}_std'] = float(np.std(values)) + stats[f'{metric_name}_min'] = float(np.min(values)) + stats[f'{metric_name}_max'] = float(np.max(values)) + + return stats + + def print_summary(self): + """Print comprehensive evaluation summary.""" + print("\n" + "="*70) + print("MODEL EVALUATION METRICS") + print("="*70) + + global_metrics = self.get_global_metrics() + + print("\nGlobal Metrics (across all frames):") + print(f" Total pixels evaluated: {global_metrics['total_pixels']:,}") + print(f" True Positives (TP): {global_metrics['global_tp']:,}") + print(f" False Positives (FP): {global_metrics['global_fp']:,}") + print(f" False Negatives (FN): {global_metrics['global_fn']:,}") + print(f" True Negatives (TN): {global_metrics['global_tn']:,}") + print(f"\n Accuracy: {global_metrics['accuracy']:.4f}") + print(f" Precision: {global_metrics['precision']:.4f}") + print(f" Recall: {global_metrics['recall']:.4f}") + print(f" F1 Score: {global_metrics['f1_score']:.4f}") + print(f" IoU: {global_metrics['iou']:.4f}") + + if self.frame_metrics: + print(f"\nPer-Frame Statistics ({len(self.frame_metrics)} frames):") + stats = self.get_frame_statistics() + + for metric in ['accuracy', 'precision', 'recall', 'f1_score', 'iou']: + mean = stats[f'{metric}_mean'] + std = stats[f'{metric}_std'] + min_val = stats[f'{metric}_min'] + max_val = stats[f'{metric}_max'] + print(f" {metric.replace('_', ' ').title():20s} " + f"mean={mean:.4f} ±{std:.4f} " + f"[{min_val:.4f}, {max_val:.4f}]") + + print("="*70 + "\n") + + def save_json(self, output_file): + """ + Save evaluation results to JSON file. + + Parameters: + output_file: Path or str - File path to save evaluation data + """ + evaluation_data = { + 'global_metrics': self.get_global_metrics(), + 'frame_statistics': self.get_frame_statistics(), + 'per_frame_metrics': self.frame_metrics, + 'threshold': self.threshold, + 'num_frames': len(self.frame_metrics) + } + + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + json.dump(evaluation_data, f, indent=2) + + print(f"Evaluation metrics saved to: {output_path}") + + +def evaluate_model_on_dataset(model, dataset, device, use_amp=False, + amp_dtype=torch.float16, threshold=0.5): + """ + Evaluate model on entire dataset and return metrics. 
+ + Parameters: + model: nn.Module - The trained model + dataset: Dataset - Dataset to evaluate on (XPointDataset, not patch dataset) + device: torch.device - Device to run evaluation on + use_amp: bool - Whether to use automatic mixed precision + amp_dtype: torch.dtype - Data type for mixed precision + threshold: float - Threshold for binary classification + + Returns: + ModelEvaluator - Evaluator object with computed metrics + """ + model.eval() + evaluator = ModelEvaluator(threshold=threshold) + + with torch.no_grad(): + for item in dataset: + fnum = item["fnum"] + all_torch = item["all"].unsqueeze(0).to(device) + mask_gt = item["mask"][0].cpu().numpy() # Remove channel dimension + + # Get prediction + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + pred_mask = model(all_torch) + pred_prob = torch.sigmoid(pred_mask) + + # Convert to numpy (handle BFloat16) + pred_prob_np = pred_prob[0, 0].float().cpu().numpy() + + # Add to evaluator + evaluator.add_frame(pred_prob_np, mask_gt, frame_id=fnum) + + return evaluator \ No newline at end of file
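For reference, `ModelEvaluator` can also be exercised on its own, outside `evaluate_model_on_dataset`. The sketch below is illustrative only: the random arrays are placeholders standing in for real UNet sigmoid output and ground-truth masks, and `standalone_metrics.json` is an arbitrary output path.

```python
import numpy as np
from eval_metrics import ModelEvaluator

# Placeholder data: random "probabilities" and a sparse synthetic ground-truth mask.
rng = np.random.default_rng(0)
pred_probs = rng.random((256, 256))                      # stands in for sigmoid output from the model
ground_truth = (rng.random((256, 256)) > 0.99).astype(np.float32)

evaluator = ModelEvaluator(threshold=0.5)
evaluator.add_frame(pred_probs, ground_truth, frame_id=0)

evaluator.print_summary()                                # global and per-frame TP/FP/FN/TN, F1, IoU
evaluator.save_json("standalone_metrics.json")
```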