|
| 1 | +import time |
| 2 | +import torch |
| 3 | +import numpy as np |
| 4 | +import argparse |
| 5 | +import sys |
| 6 | +import os |
| 7 | +from pathlib import Path |
| 8 | + |
| 9 | +# Ensure we can import from the current directory and pgkylFrontEnd (via PYTHONPATH) |
| 10 | +sys.path.append(str(Path(__file__).parent)) |
| 11 | + |
| 12 | +try: |
| 13 | + from XPointMLTest import UNet, loadPgkylDataFromCache, cachedPgkylDataExists |
| 14 | + from utils import auxFuncs, gkData |
| 15 | +except ImportError as e: |
| 16 | + print(f"Error importing modules: {e}") |
| 17 | + print("Make sure you have sourced envPyTorch.sh and are running with correct PYTHONPATH.") |
| 18 | + sys.exit(1) |
| 19 | + |
| 20 | +def compare_hessian_vs_ml(param_file, cache_dir, model_path, frame_list, device='cuda'): |
| 21 | + print(f"Comparing Hessian vs ML on frames {frame_list}") |
| 22 | + print(f"Device: {device}") |
| 23 | + print(f"Model: {model_path}") |
| 24 | + print(f"Param File: {param_file}") |
| 25 | + print(f"Cache Dir: {cache_dir}") |
| 26 | + |
| 27 | + # Load Model |
| 28 | + # UNet signature: def __init__(self, input_channels=4, base_channels=32, *, dropout_rate): |
| 29 | + model = UNet(input_channels=4, base_channels=32, dropout_rate=0.15).to(device) |
| 30 | + try: |
| 31 | + checkpoint = torch.load(model_path, map_location=device) |
| 32 | + if 'model_state_dict' in checkpoint: |
| 33 | + model.load_state_dict(checkpoint['model_state_dict']) |
| 34 | + else: |
| 35 | + model.load_state_dict(checkpoint) |
| 36 | + except Exception as e: |
| 37 | + print(f"Failed to load model: {e}") |
| 38 | + sys.exit(1) |
| 39 | + |
| 40 | + model.eval() |
| 41 | + |
| 42 | + hessian_times = [] |
| 43 | + ml_times = [] |
| 44 | + |
| 45 | + # Warmup |
| 46 | + dummy_input = torch.randn(1, 4, 1024, 1024).to(device) |
| 47 | + with torch.no_grad(): |
| 48 | + _ = model(dummy_input) |
| 49 | + |
| 50 | + print(f"{'Frame':<10} | {'Hessian (s)':<15} | {'ML (s)':<15} | {'Speedup':<10}") |
| 51 | + print("-" * 60) |
| 52 | + |
| 53 | + cache_path = Path(cache_dir) if cache_dir else None |
| 54 | + |
| 55 | + for fnum in frame_list: |
| 56 | + try: |
| 57 | + psi = None |
| 58 | + dx = None |
| 59 | + |
| 60 | + # Try loading from cache first |
| 61 | + if cache_path and cachedPgkylDataExists(cache_path, fnum, "psi"): |
| 62 | + fields_to_load = {"psi": None, "coords": None} |
| 63 | + loaded = loadPgkylDataFromCache(cache_path, fnum, fields_to_load) |
| 64 | + psi = loaded["psi"] |
| 65 | + coords = loaded["coords"] |
| 66 | + # Calculate dx from coords |
| 67 | + dx = [c[1] - c[0] for c in coords] |
| 68 | + else: |
| 69 | + # Fallback to gkData (might fail if getData.py is buggy) |
| 70 | + params = {} |
| 71 | + params["polyOrderOverride"] = 0 |
| 72 | + var = gkData.gkData(str(param_file), fnum, 'psi', params).compactRead() |
| 73 | + psi = var.data |
| 74 | + dx = var.dx |
| 75 | + |
| 76 | + if psi is None: |
| 77 | + print(f"Could not load data for frame {fnum}") |
| 78 | + continue |
| 79 | + |
| 80 | + # --- Measure Hessian Time --- |
| 81 | + t0 = time.time() |
| 82 | + |
| 83 | + # Replicating Hessian logic |
| 84 | + critPoints = auxFuncs.getCritPoints(psi) |
| 85 | + [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) |
| 86 | + |
| 87 | + t1 = time.time() |
| 88 | + hessian_time = t1 - t0 |
| 89 | + hessian_times.append(hessian_time) |
| 90 | + |
| 91 | + # --- Measure ML Time --- |
| 92 | + # Preprocess - Calculate derived fields |
| 93 | + [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,dx) |
| 94 | + [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,dx) |
| 95 | + [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,dx) |
| 96 | + bx = df_dy |
| 97 | + by = -df_dx |
| 98 | + # mu0 is usually 1.0 in normalized units or available in var.mu0 |
| 99 | + # If we loaded from cache, we don't have var.mu0. |
| 100 | + # Assuming mu0=1.0 for now as it's common in normalized simulations, |
| 101 | + # or we could read it from param file, but let's stick to 1.0 or check if we can get it. |
| 102 | + # In XPointMLTest.py: jz = -(d2f_dxdx + d2f_dydy) / var.mu0 |
| 103 | + # In getConst.py: self.mu0 = mu0 (from param file). |
| 104 | + # Let's assume mu0=1.0 to avoid reading param file again, or just use 1.0. |
| 105 | + mu0 = 1.0 |
| 106 | + jz = -(d2f_dxdx + d2f_dydy) / mu0 |
| 107 | + |
| 108 | + # Normalize (using same logic as XPointMLTest.py) |
| 109 | + psi_norm = (psi - psi.mean()) / (psi.std() + 1e-8) |
| 110 | + bx_norm = (bx - bx.mean()) / (bx.std() + 1e-8) |
| 111 | + by_norm = (by - by.mean()) / (by.std() + 1e-8) |
| 112 | + jz_norm = (jz - jz.mean()) / (jz.std() + 1e-8) |
| 113 | + |
| 114 | + # Stack |
| 115 | + psi_torch = torch.from_numpy(psi_norm).float().unsqueeze(0) |
| 116 | + bx_torch = torch.from_numpy(bx_norm).float().unsqueeze(0) |
| 117 | + by_torch = torch.from_numpy(by_norm).float().unsqueeze(0) |
| 118 | + jz_torch = torch.from_numpy(jz_norm).float().unsqueeze(0) |
| 119 | + |
| 120 | + input_tensor = torch.cat((psi_torch, bx_torch, by_torch, jz_torch)).unsqueeze(0).to(device) |
| 121 | + |
| 122 | + if device == 'cuda': |
| 123 | + torch.cuda.synchronize() |
| 124 | + t2 = time.time() |
| 125 | + |
| 126 | + with torch.no_grad(): |
| 127 | + output = model(input_tensor) |
| 128 | + prob = torch.sigmoid(output) |
| 129 | + mask = (prob > 0.5).float() |
| 130 | + |
| 131 | + if device == 'cuda': |
| 132 | + torch.cuda.synchronize() |
| 133 | + t3 = time.time() |
| 134 | + |
| 135 | + ml_time = t3 - t2 |
| 136 | + ml_times.append(ml_time) |
| 137 | + |
| 138 | + print(f"{fnum:<10} | {hessian_time:<15.4f} | {ml_time:<15.4f} | {hessian_time/ml_time:<10.2f}") |
| 139 | + |
| 140 | + except Exception as e: |
| 141 | + print(f"Error processing frame {fnum}: {e}") |
| 142 | + import traceback |
| 143 | + traceback.print_exc() |
| 144 | + continue |
| 145 | + |
| 146 | + if hessian_times and ml_times: |
| 147 | + avg_hessian = np.mean(hessian_times) |
| 148 | + avg_ml = np.mean(ml_times) |
| 149 | + |
| 150 | + print("\n" + "="*60) |
| 151 | + print(f"Average Hessian Time: {avg_hessian:.4f}s") |
| 152 | + print(f"Average ML Time: {avg_ml:.4f}s") |
| 153 | + print(f"Average Speedup: {avg_hessian/avg_ml:.2f}x") |
| 154 | + print("="*60) |
| 155 | + else: |
| 156 | + print("No frames processed successfully.") |
| 157 | + |
| 158 | +if __name__ == "__main__": |
| 159 | + parser = argparse.ArgumentParser(description="Compare Hessian-based vs ML-based X-point detection performance.") |
| 160 | + parser.add_argument('--paramFile', type=str, required=True, help="Path to the parameter file") |
| 161 | + parser.add_argument('--xptCacheDir', type=str, default=None, help="Path to cache directory (optional)") |
| 162 | + parser.add_argument('--modelPath', type=str, required=True, help="Path to the trained model checkpoint (.pt)") |
| 163 | + parser.add_argument('--frames', type=str, default="141-150", help="Range of frames (e.g., '141-150' or '141,142,143')") |
| 164 | + parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run ML model on") |
| 165 | + |
| 166 | + args = parser.parse_args() |
| 167 | + |
| 168 | + # Parse frames |
| 169 | + if '-' in args.frames: |
| 170 | + start, end = map(int, args.frames.split('-')) |
| 171 | + frames = range(start, end + 1) |
| 172 | + else: |
| 173 | + frames = [int(x) for x in args.frames.split(',')] |
| 174 | + |
| 175 | + compare_hessian_vs_ml(args.paramFile, args.xptCacheDir, args.modelPath, frames, args.device) |