Commit 5c8bb6d
Removed unnecessary files
1 parent 9fbdff2

File tree

- XPointMLTest.py
- git_utils.py
- hessian_comparison.py

3 files changed: +246 -0 lines changed


XPointMLTest.py

Lines changed: 8 additions & 0 deletions
@@ -31,6 +31,9 @@
 # Import evaluation metrics module
 from eval_metrics import ModelEvaluator, evaluate_model_on_dataset

+# Import git utils
+from git_utils import print_git_info
+
 def set_seed(seed):
     """
     Set random seed for reproducibility across all libraries
@@ -1070,6 +1073,11 @@ def load_model_checkpoint(model, optimizer, checkpoint_path, scaler=None):


 def main():
+    # Print git repository information
+    # Use the directory of the current script as the repo path
+    repo_path = os.path.dirname(os.path.abspath(__file__))
+    print_git_info(repo_path)
+
     args = parseCommandLineArgs()

     # Set seed for reproducibility
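With this change, main() reports the repository state before doing anything else, so every run can be traced back to an exact commit. Given the print_git_info implementation in git_utils.py below, the startup output is a block of the following shape (angle-bracketed values are placeholders, not real output):

------------------------------
Git Repository Information:
Commit Hash: <40-character SHA>
Branch Name: <branch>
Remote URL: <origin URL>
------------------------------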

git_utils.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
import subprocess
import os

def get_git_info(repo_path='.'):
    """
    Retrieves git information: commit hash, remote URL, and branch name.

    Args:
        repo_path (str): Path to the git repository. Defaults to current directory.

    Returns:
        dict: Dictionary containing 'commit_hash', 'remote_url', and 'branch_name'.
              Values are None if retrieval fails.
    """
    def run_git_command(command):
        try:
            result = subprocess.run(
                command,
                cwd=repo_path,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True
            )
            return result.stdout.strip()
        except (subprocess.CalledProcessError, FileNotFoundError):
            return None

    commit_hash = run_git_command(['git', 'rev-parse', 'HEAD'])
    remote_url = run_git_command(['git', 'config', '--get', 'remote.origin.url'])
    branch_name = run_git_command(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])

    return {
        'commit_hash': commit_hash,
        'remote_url': remote_url,
        'branch_name': branch_name
    }

def print_git_info(repo_path='.'):
    """
    Prints git information to stdout.
    """
    info = get_git_info(repo_path)
    print("-" * 30)
    print("Git Repository Information:")
    if info['commit_hash']:
        print(f"Commit Hash: {info['commit_hash']}")
    else:
        print("Commit Hash: Not available")

    if info['branch_name']:
        print(f"Branch Name: {info['branch_name']}")
    else:
        print("Branch Name: Not available")

    if info['remote_url']:
        print(f"Remote URL: {info['remote_url']}")
    else:
        print("Remote URL: Not available")
    print("-" * 30)

if __name__ == "__main__":
    print_git_info()
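Beyond the startup logging in XPointMLTest.py, the module can be used on its own. A minimal usage sketch, assuming git_utils.py is importable; the run_tag naming scheme is illustrative, not part of this commit:

from git_utils import get_git_info, print_git_info

print_git_info('.')  # human-readable block, as printed at the start of XPointMLTest.main()

# Or capture the fields programmatically, e.g. to tag experiment outputs:
info = get_git_info('.')
if info['commit_hash']:
    run_tag = f"{info['branch_name']}-{info['commit_hash'][:8]}"
else:
    run_tag = "unknown-version"  # git not installed, or not inside a repository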

hessian_comparison.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
import time
import torch
import numpy as np
import argparse
import sys
import os
from pathlib import Path

# Ensure we can import from the current directory and pgkylFrontEnd (via PYTHONPATH)
sys.path.append(str(Path(__file__).parent))

try:
    from XPointMLTest import UNet, loadPgkylDataFromCache, cachedPgkylDataExists
    from utils import auxFuncs, gkData
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Make sure you have sourced envPyTorch.sh and are running with correct PYTHONPATH.")
    sys.exit(1)

def compare_hessian_vs_ml(param_file, cache_dir, model_path, frame_list, device='cuda'):
    print(f"Comparing Hessian vs ML on frames {frame_list}")
    print(f"Device: {device}")
    print(f"Model: {model_path}")
    print(f"Param File: {param_file}")
    print(f"Cache Dir: {cache_dir}")

    # Load the trained model
    # UNet signature: def __init__(self, input_channels=4, base_channels=32, *, dropout_rate):
    model = UNet(input_channels=4, base_channels=32, dropout_rate=0.15).to(device)
    try:
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
    except Exception as e:
        print(f"Failed to load model: {e}")
        sys.exit(1)

    model.eval()

    hessian_times = []
    ml_times = []

    # Warm up with one dummy inference so one-time CUDA/model initialization
    # does not skew the first timed measurement
    dummy_input = torch.randn(1, 4, 1024, 1024).to(device)
    with torch.no_grad():
        _ = model(dummy_input)

    print(f"{'Frame':<10} | {'Hessian (s)':<15} | {'ML (s)':<15} | {'Speedup':<10}")
    print("-" * 60)

    cache_path = Path(cache_dir) if cache_dir else None

    for fnum in frame_list:
        try:
            psi = None
            dx = None

            # Try loading from cache first
            if cache_path and cachedPgkylDataExists(cache_path, fnum, "psi"):
                fields_to_load = {"psi": None, "coords": None}
                loaded = loadPgkylDataFromCache(cache_path, fnum, fields_to_load)
                psi = loaded["psi"]
                coords = loaded["coords"]
                # Calculate dx from coords (assumes uniform grid spacing)
                dx = [c[1] - c[0] for c in coords]
            else:
                # Fall back to gkData (may fail if getData.py is buggy)
                params = {}
                params["polyOrderOverride"] = 0
                var = gkData.gkData(str(param_file), fnum, 'psi', params).compactRead()
                psi = var.data
                dx = var.dx

            if psi is None:
                print(f"Could not load data for frame {fnum}")
                continue

            # --- Measure Hessian Time ---
            t0 = time.time()

            # Replicating Hessian logic
            critPoints = auxFuncs.getCritPoints(psi)
            [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints)

            t1 = time.time()
            hessian_time = t1 - t0
            hessian_times.append(hessian_time)

            # --- Measure ML Time ---
            # Preprocess - Calculate derived fields
            [df_dx, df_dy, df_dz] = auxFuncs.genGradient(psi, dx)
            [d2f_dxdx, d2f_dxdy, d2f_dxdz] = auxFuncs.genGradient(df_dx, dx)
            [d2f_dydx, d2f_dydy, d2f_dydz] = auxFuncs.genGradient(df_dy, dx)
            bx = df_dy
            by = -df_dx
            # XPointMLTest.py computes jz = -(d2f_dxdx + d2f_dydy) / var.mu0, where
            # var.mu0 comes from the param file (see getConst.py). When psi is loaded
            # from the cache there is no var object, so assume mu0 = 1.0 (the usual
            # value in normalized simulation units) rather than re-reading the param file.
            mu0 = 1.0
            jz = -(d2f_dxdx + d2f_dydy) / mu0

            # Normalize (using same logic as XPointMLTest.py)
            psi_norm = (psi - psi.mean()) / (psi.std() + 1e-8)
            bx_norm = (bx - bx.mean()) / (bx.std() + 1e-8)
            by_norm = (by - by.mean()) / (by.std() + 1e-8)
            jz_norm = (jz - jz.mean()) / (jz.std() + 1e-8)

            # Stack the four channels into a (1, 4, H, W) input tensor
            psi_torch = torch.from_numpy(psi_norm).float().unsqueeze(0)
            bx_torch = torch.from_numpy(bx_norm).float().unsqueeze(0)
            by_torch = torch.from_numpy(by_norm).float().unsqueeze(0)
            jz_torch = torch.from_numpy(jz_norm).float().unsqueeze(0)

            input_tensor = torch.cat((psi_torch, bx_torch, by_torch, jz_torch)).unsqueeze(0).to(device)

            if device == 'cuda':
                torch.cuda.synchronize()
            t2 = time.time()

            with torch.no_grad():
                output = model(input_tensor)
                prob = torch.sigmoid(output)
                mask = (prob > 0.5).float()

            if device == 'cuda':
                torch.cuda.synchronize()
            t3 = time.time()

            ml_time = t3 - t2
            ml_times.append(ml_time)

            print(f"{fnum:<10} | {hessian_time:<15.4f} | {ml_time:<15.4f} | {hessian_time/ml_time:<10.2f}")

        except Exception as e:
            print(f"Error processing frame {fnum}: {e}")
            import traceback
            traceback.print_exc()
            continue

    if hessian_times and ml_times:
        avg_hessian = np.mean(hessian_times)
        avg_ml = np.mean(ml_times)

        print("\n" + "="*60)
        print(f"Average Hessian Time: {avg_hessian:.4f}s")
        print(f"Average ML Time: {avg_ml:.4f}s")
        print(f"Average Speedup: {avg_hessian/avg_ml:.2f}x")
        print("="*60)
    else:
        print("No frames processed successfully.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compare Hessian-based vs ML-based X-point detection performance.")
    parser.add_argument('--paramFile', type=str, required=True, help="Path to the parameter file")
    parser.add_argument('--xptCacheDir', type=str, default=None, help="Path to cache directory (optional)")
    parser.add_argument('--modelPath', type=str, required=True, help="Path to the trained model checkpoint (.pt)")
    parser.add_argument('--frames', type=str, default="141-150", help="Range of frames (e.g., '141-150' or '141,142,143')")
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run ML model on")

    args = parser.parse_args()

    # Parse frames: either an inclusive range 'A-B' or a comma-separated list
    if '-' in args.frames:
        start, end = map(int, args.frames.split('-'))
        frames = range(start, end + 1)
    else:
        frames = [int(x) for x in args.frames.split(',')]

    compare_hessian_vs_ml(args.paramFile, args.xptCacheDir, args.modelPath, frames, args.device)
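A hypothetical invocation with placeholder paths (only the flags are defined by the script's argparse setup above):

python hessian_comparison.py --paramFile <params file> --modelPath <checkpoint.pt> --xptCacheDir <cache dir> --frames 141-150

The '141-150' range is inclusive, so ten frames are timed; the script prints per-frame Hessian and ML wall-clock times with their ratio, then the averages over all successfully processed frames.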
