Skip to content

Commit 915c48c

Browse files
committed
clean up modules
1 parent c698457 commit 915c48c

File tree

1 file changed

+37
-230
lines changed

1 file changed

+37
-230
lines changed

calphy/phase.py

Lines changed: 37 additions & 230 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,10 @@
3838
import calphy.helpers as ph
3939
from calphy.errors import *
4040
from calphy.input import generate_metadata
41+
import calphy.structural_monitoring as sm
4142

42-
# Import FrameAccumulator for structural monitoring
43-
try:
44-
from structid import FrameAccumulator
45-
from structid.descriptors import SteinhardtDescriptor
46-
from structid.detectors import (
47-
GPRDetector,
48-
AdaptiveZScoreDetector,
49-
CumulativeMeanDetector,
50-
)
51-
52-
STRUCTID_AVAILABLE = True
53-
except ImportError:
54-
STRUCTID_AVAILABLE = False
43+
# Check if structid is available
44+
STRUCTID_AVAILABLE = sm.STRUCTID_AVAILABLE
5545

5646

5747
class Phase:
@@ -1069,26 +1059,20 @@ def _reversible_scaling_forward(self, iteration=1):
10691059
f"Temperature interval for monitoring: {self.calc.structural_monitoring.temperature_interval} K"
10701060
)
10711061

1072-
# Create descriptor and detector
1073-
descriptor = SteinhardtDescriptor(
1074-
l_values=tuple(self.calc.structural_monitoring.l_values),
1062+
# Initialize accumulator using structural_monitoring module
1063+
accumulator = sm.initialize_accumulator(
1064+
l_values=self.calc.structural_monitoring.l_values,
10751065
cutoff=self.calc.structural_monitoring.cutoff,
1066+
detector_type="cumulative",
1067+
threshold=3,
10761068
)
1077-
detector = CumulativeMeanDetector(threshold=3)
1078-
1079-
# Initialize accumulator
1080-
accumulator = FrameAccumulator(descriptor=descriptor, detector=detector)
10811069

10821070
# Set up dump for equilibration trajectory to pre-train
1083-
pretrain_dump = os.path.join(self.simfolder, "_tmp_pretrain_traj.dump")
1084-
pretrain_dump_interval = (
1085-
100 # Fixed interval for equilibration pre-training
1086-
)
1087-
lmp.command(
1088-
f"dump _pretrain all custom {pretrain_dump_interval} {pretrain_dump} id type x y z"
1071+
pretrain_dump = sm.setup_pretraining_dump(
1072+
lmp=lmp, simfolder=self.simfolder, dump_interval=100
10891073
)
10901074
self.logger.info(
1091-
f"Dumping equilibration trajectory every {pretrain_dump_interval} steps for pre-training"
1075+
"Dumping equilibration trajectory every 100 steps for pre-training"
10921076
)
10931077

10941078
self.logger.info(f"Starting equilibration with constrained com: {iteration}")
@@ -1099,35 +1083,16 @@ def _reversible_scaling_forward(self, iteration=1):
10991083
if accumulator is not None and pretrain_dump is not None:
11001084
lmp.command("undump _pretrain")
11011085

1102-
if os.path.exists(pretrain_dump):
1103-
from ase.io import read
1104-
1105-
try:
1106-
self.logger.info(
1107-
"Pre-training FrameAccumulator on equilibration trajectory"
1108-
)
1109-
# Read all frames from equilibration
1110-
traj = read(pretrain_dump, format="lammps-dump-text", index=":")
1111-
self.logger.info(
1112-
f"Read {len(traj)} frames from equilibration trajectory"
1113-
)
1086+
success = sm.pretrain_accumulator(
1087+
accumulator=accumulator, pretrain_dump=pretrain_dump, logger=self.logger
1088+
)
11141089

1115-
# Pre-train on all frames
1116-
with warnings.catch_warnings():
1117-
warnings.filterwarnings(
1118-
"ignore", category=Warning, module="sklearn"
1119-
)
1120-
for atoms in traj:
1121-
accumulator.calculate_single_frame(atoms)
1090+
if not success:
1091+
accumulator = None
11221092

1123-
self.logger.info("FrameAccumulator pre-trained successfully")
1124-
except Exception as e:
1125-
self.logger.warning(f"Failed to pre-train FrameAccumulator: {e}")
1126-
accumulator = None
1127-
finally:
1128-
# Clean up trajectory file
1129-
if os.path.exists(pretrain_dump):
1130-
os.remove(pretrain_dump)
1093+
# Clean up trajectory file
1094+
if os.path.exists(pretrain_dump):
1095+
os.remove(pretrain_dump)
11311096

11321097
lmp.command("variable flambda equal ramp(${li},${lf})")
11331098
lmp.command("variable blambda equal ramp(${lf},${li})")
@@ -1305,96 +1270,28 @@ def _reversible_scaling_forward(self, iteration=1):
13051270
# Unfix to avoid duplicates at boundaries
13061271
lmp.command("unfix f3")
13071272

1308-
# Get atomic data directly from LAMMPS instead of writing/reading file
1273+
# Get atomic data directly from LAMMPS
13091274
try:
1310-
from ase import Atoms
1311-
1312-
# Gather atomic positions and types
1313-
x = lmp.gather_atoms("x", 3) # positions (natoms × 3)
1314-
atom_types = lmp.gather_atoms("type", 0) # atom types (natoms,)
1275+
atoms = sm.extract_atoms_from_lammps(lmp)
13151276

1316-
# Get box dimensions
1317-
boxlo, boxhi, xy, yz, xz, periodicity, box_change = (
1318-
lmp.extract_box()
1277+
# Process frame and get statistics
1278+
stats = sm.process_monitoring_frame(
1279+
accumulator=accumulator, atoms=atoms, logger=self.logger
13191280
)
13201281

1321-
# Reshape positions to (natoms, 3)
1322-
positions = x.reshape(-1, 3)
1323-
1324-
# Create cell from box bounds
1325-
xlo, ylo, zlo = boxlo
1326-
xhi, yhi, zhi = boxhi
1327-
cell = [[xhi - xlo, 0, 0], [xy, yhi - ylo, 0], [xz, yz, zhi - zlo]]
1328-
1329-
# Create ASE Atoms object
1330-
atoms = Atoms(
1331-
numbers=atom_types, positions=positions, cell=cell, pbc=True
1332-
)
1333-
1334-
# Suppress sklearn convergence warnings during GPR fitting
1335-
with warnings.catch_warnings():
1336-
warnings.filterwarnings(
1337-
"ignore", category=Warning, module="sklearn"
1338-
)
1339-
is_consistent = accumulator.calculate_single_frame(atoms)
1340-
13411282
# Get detector statistics for logging
13421283
lambda_value = li + (lf - li) * current_step / total_steps
13431284
apparent_temp = t0 / lambda_value
13441285

1345-
# Extract data from accumulator using get_results()
1346-
distance = -1.0
1347-
mean_val = -1.0
1348-
std_val = -1.0
1349-
is_flagged = False # Whether this frame is flagged as anomalous
1350-
1351-
try:
1352-
frames_arr, vectors, distances_arr, flagged = (
1353-
accumulator.get_results()
1354-
)
1355-
# Get the latest distance (last element in distances array)
1356-
if len(distances_arr) > 0:
1357-
distance = distances_arr[-1]
1358-
1359-
# Check if current frame index is in the flagged array
1360-
current_frame_idx = (
1361-
len(frames_arr) - 1 if len(frames_arr) > 0 else -1
1362-
)
1363-
is_flagged = current_frame_idx in flagged
1364-
1365-
# Get detector statistics using common interface
1366-
# All detectors now have get_statistics() method
1367-
if hasattr(accumulator.detector, "get_statistics"):
1368-
# Stateful detectors (Adaptive, Cumulative, GPR) don't need arguments
1369-
# Stateless detectors (StdDev, Threshold) take historical_distances
1370-
try:
1371-
stats = accumulator.detector.get_statistics()
1372-
except TypeError:
1373-
# Stateless detector needs historical_distances argument
1374-
stats = accumulator.detector.get_statistics(
1375-
distances_arr
1376-
)
1377-
1378-
mean_val = (
1379-
stats.get("mean", -1.0)
1380-
if stats.get("mean") is not None
1381-
else -1.0
1382-
)
1383-
std_val = (
1384-
stats.get("std", -1.0)
1385-
if stats.get("std") is not None
1386-
else -1.0
1387-
)
1388-
except Exception as e:
1389-
self.logger.debug(f"Could not extract detector statistics: {e}")
1390-
13911286
# Write to monitoring file
13921287
monitor_out.write(
1393-
f"{block_idx} {current_step} {lambda_value:.6f} {apparent_temp:.2f} {int(is_flagged)} {distance:.6f} {mean_val:.6f} {std_val:.6f}\n"
1288+
f"{block_idx} {current_step} {lambda_value:.6f} {apparent_temp:.2f} "
1289+
f"{int(stats['is_flagged'])} {stats['distance']:.6f} "
1290+
f"{stats['mean']:.6f} {stats['std']:.6f}\n"
13941291
)
13951292
monitor_out.flush()
13961293

1397-
if is_flagged:
1294+
if stats["is_flagged"]:
13981295
self.logger.warning(
13991296
f"⚠️ STRUCTURAL CHANGE DETECTED at step {current_step} (λ={lambda_value:.4f}, T={apparent_temp:.2f}K) "
14001297
f"during forward sweep (iteration {iteration})"
@@ -1427,106 +1324,16 @@ def _reversible_scaling_forward(self, iteration=1):
14271324
self.logger.info(f"Structural monitoring data saved to {monitor_file}")
14281325

14291326
# Generate monitoring plot
1430-
try:
1431-
import matplotlib
1432-
1433-
matplotlib.use("Agg") # Non-interactive backend
1434-
import matplotlib.pyplot as plt
1435-
1436-
# Read monitoring data
1437-
data = np.loadtxt(monitor_file)
1438-
if len(data.shape) == 1: # Single row
1439-
data = data.reshape(1, -1)
1440-
1441-
block_idx_arr = data[:, 0]
1442-
step_arr = data[:, 1]
1443-
lambda_arr = data[:, 2]
1444-
temp_arr = data[:, 3]
1445-
is_flagged_arr = data[:, 4]
1446-
distance_arr = data[:, 5]
1447-
1448-
# Calculate cumulative mean and std
1449-
cumulative_mean = np.array(
1450-
[np.mean(distance_arr[: i + 1]) for i in range(len(distance_arr))]
1451-
)
1452-
cumulative_std = np.array(
1453-
[np.std(distance_arr[: i + 1]) for i in range(len(distance_arr))]
1454-
)
1455-
1456-
# Create plot
1457-
fig, ax = plt.subplots(figsize=(10, 6))
1458-
ax.plot(
1459-
step_arr,
1460-
distance_arr,
1461-
"o-",
1462-
label="Descriptor Distance",
1463-
linewidth=2,
1464-
markersize=6,
1465-
)
1466-
ax.plot(
1467-
step_arr,
1468-
cumulative_mean,
1469-
"k-",
1470-
label="Cumulative Mean",
1471-
linewidth=2,
1472-
)
1473-
ax.plot(
1474-
step_arr,
1475-
cumulative_mean + 2.0 * cumulative_std,
1476-
"--",
1477-
label="Mean + 2σ",
1478-
linewidth=1.5,
1479-
)
1480-
ax.plot(
1481-
step_arr,
1482-
cumulative_mean + 2.5 * cumulative_std,
1483-
"--",
1484-
label="Mean + 2.5σ",
1485-
linewidth=1.5,
1486-
)
1487-
ax.plot(
1488-
step_arr,
1489-
cumulative_mean + 3.0 * cumulative_std,
1490-
"--",
1491-
label="Mean + 3σ",
1492-
linewidth=1.5,
1493-
)
1494-
1495-
# Highlight flagged points (anomalies)
1496-
flagged_indices = np.where(is_flagged_arr > 0.5)[0]
1497-
if len(flagged_indices) > 0:
1498-
ax.scatter(
1499-
step_arr[flagged_indices],
1500-
distance_arr[flagged_indices],
1501-
c="red",
1502-
s=200,
1503-
marker="X",
1504-
edgecolors="darkred",
1505-
linewidth=2,
1506-
label="Flagged Anomalies",
1507-
zorder=5,
1508-
)
1509-
1510-
ax.set_xlabel("Step", fontsize=12)
1511-
ax.set_ylabel("Descriptor Distance", fontsize=12)
1512-
ax.set_title(
1513-
f"Structural Monitoring - Forward Sweep (Iteration {iteration})",
1514-
fontsize=14,
1515-
)
1516-
ax.legend(loc="best", fontsize=10)
1517-
ax.grid(True, alpha=0.3)
1518-
1519-
# Save plot
1520-
plot_file = os.path.join(
1521-
self.simfolder, f"structural_monitoring_forward_{iteration}.png"
1522-
)
1523-
plt.tight_layout()
1524-
plt.savefig(plot_file, dpi=150, bbox_inches="tight")
1525-
plt.close()
1327+
plot_file = os.path.join(
1328+
self.simfolder, f"structural_monitoring_forward_{iteration}.png"
1329+
)
1330+
success = sm.generate_monitoring_plot(
1331+
monitor_file=monitor_file, output_file=plot_file, iteration=iteration
1332+
)
1333+
if success:
15261334
self.logger.info(f"Monitoring plot saved to {plot_file}")
1527-
1528-
except Exception as e:
1529-
self.logger.warning(f"Failed to generate monitoring plot: {e}")
1335+
else:
1336+
self.logger.warning("Failed to generate monitoring plot")
15301337

15311338
# Merge all cycle files into final output
15321339
self.logger.info(f"Merging cycle files into ts.forward_{iteration}.dat")

0 commit comments

Comments
 (0)