3838import calphy .helpers as ph
3939from calphy .errors import *
4040from calphy .input import generate_metadata
41+ import calphy .structural_monitoring as sm
4142
42- # Import FrameAccumulator for structural monitoring
43- try :
44- from structid import FrameAccumulator
45- from structid .descriptors import SteinhardtDescriptor
46- from structid .detectors import (
47- GPRDetector ,
48- AdaptiveZScoreDetector ,
49- CumulativeMeanDetector ,
50- )
51-
52- STRUCTID_AVAILABLE = True
53- except ImportError :
54- STRUCTID_AVAILABLE = False
# Expose the structid availability flag at module level; the actual
# import/availability probing lives in calphy.structural_monitoring, so this
# module only mirrors its result for existing callers.
STRUCTID_AVAILABLE = sm.STRUCTID_AVAILABLE
5545
5646
5747class Phase :
@@ -1069,26 +1059,20 @@ def _reversible_scaling_forward(self, iteration=1):
10691059 f"Temperature interval for monitoring: { self .calc .structural_monitoring .temperature_interval } K"
10701060 )
10711061
1072- # Create descriptor and detector
1073- descriptor = SteinhardtDescriptor (
1074- l_values = tuple ( self .calc .structural_monitoring .l_values ) ,
1062+ # Initialize accumulator using structural_monitoring module
1063+ accumulator = sm . initialize_accumulator (
1064+ l_values = self .calc .structural_monitoring .l_values ,
10751065 cutoff = self .calc .structural_monitoring .cutoff ,
1066+ detector_type = "cumulative" ,
1067+ threshold = 3 ,
10761068 )
1077- detector = CumulativeMeanDetector (threshold = 3 )
1078-
1079- # Initialize accumulator
1080- accumulator = FrameAccumulator (descriptor = descriptor , detector = detector )
10811069
10821070 # Set up dump for equilibration trajectory to pre-train
1083- pretrain_dump = os .path .join (self .simfolder , "_tmp_pretrain_traj.dump" )
1084- pretrain_dump_interval = (
1085- 100 # Fixed interval for equilibration pre-training
1086- )
1087- lmp .command (
1088- f"dump _pretrain all custom { pretrain_dump_interval } { pretrain_dump } id type x y z"
1071+ pretrain_dump = sm .setup_pretraining_dump (
1072+ lmp = lmp , simfolder = self .simfolder , dump_interval = 100
10891073 )
10901074 self .logger .info (
1091- f "Dumping equilibration trajectory every { pretrain_dump_interval } steps for pre-training"
1075+ "Dumping equilibration trajectory every 100 steps for pre-training"
10921076 )
10931077
10941078 self .logger .info (f"Starting equilibration with constrained com: { iteration } " )
@@ -1099,35 +1083,16 @@ def _reversible_scaling_forward(self, iteration=1):
10991083 if accumulator is not None and pretrain_dump is not None :
11001084 lmp .command ("undump _pretrain" )
11011085
1102- if os .path .exists (pretrain_dump ):
1103- from ase .io import read
1104-
1105- try :
1106- self .logger .info (
1107- "Pre-training FrameAccumulator on equilibration trajectory"
1108- )
1109- # Read all frames from equilibration
1110- traj = read (pretrain_dump , format = "lammps-dump-text" , index = ":" )
1111- self .logger .info (
1112- f"Read { len (traj )} frames from equilibration trajectory"
1113- )
1086+ success = sm .pretrain_accumulator (
1087+ accumulator = accumulator , pretrain_dump = pretrain_dump , logger = self .logger
1088+ )
11141089
1115- # Pre-train on all frames
1116- with warnings .catch_warnings ():
1117- warnings .filterwarnings (
1118- "ignore" , category = Warning , module = "sklearn"
1119- )
1120- for atoms in traj :
1121- accumulator .calculate_single_frame (atoms )
1090+ if not success :
1091+ accumulator = None
11221092
1123- self .logger .info ("FrameAccumulator pre-trained successfully" )
1124- except Exception as e :
1125- self .logger .warning (f"Failed to pre-train FrameAccumulator: { e } " )
1126- accumulator = None
1127- finally :
1128- # Clean up trajectory file
1129- if os .path .exists (pretrain_dump ):
1130- os .remove (pretrain_dump )
1093+ # Clean up trajectory file
1094+ if os .path .exists (pretrain_dump ):
1095+ os .remove (pretrain_dump )
11311096
11321097 lmp .command ("variable flambda equal ramp(${li},${lf})" )
11331098 lmp .command ("variable blambda equal ramp(${lf},${li})" )
@@ -1305,96 +1270,28 @@ def _reversible_scaling_forward(self, iteration=1):
13051270 # Unfix to avoid duplicates at boundaries
13061271 lmp .command ("unfix f3" )
13071272
1308- # Get atomic data directly from LAMMPS instead of writing/reading file
1273+ # Get atomic data directly from LAMMPS
13091274 try :
1310- from ase import Atoms
1311-
1312- # Gather atomic positions and types
1313- x = lmp .gather_atoms ("x" , 3 ) # positions (natoms × 3)
1314- atom_types = lmp .gather_atoms ("type" , 0 ) # atom types (natoms,)
1275+ atoms = sm .extract_atoms_from_lammps (lmp )
13151276
1316- # Get box dimensions
1317- boxlo , boxhi , xy , yz , xz , periodicity , box_change = (
1318- lmp . extract_box ()
1277+ # Process frame and get statistics
1278+ stats = sm . process_monitoring_frame (
1279+ accumulator = accumulator , atoms = atoms , logger = self . logger
13191280 )
13201281
1321- # Reshape positions to (natoms, 3)
1322- positions = x .reshape (- 1 , 3 )
1323-
1324- # Create cell from box bounds
1325- xlo , ylo , zlo = boxlo
1326- xhi , yhi , zhi = boxhi
1327- cell = [[xhi - xlo , 0 , 0 ], [xy , yhi - ylo , 0 ], [xz , yz , zhi - zlo ]]
1328-
1329- # Create ASE Atoms object
1330- atoms = Atoms (
1331- numbers = atom_types , positions = positions , cell = cell , pbc = True
1332- )
1333-
1334- # Suppress sklearn convergence warnings during GPR fitting
1335- with warnings .catch_warnings ():
1336- warnings .filterwarnings (
1337- "ignore" , category = Warning , module = "sklearn"
1338- )
1339- is_consistent = accumulator .calculate_single_frame (atoms )
1340-
13411282 # Get detector statistics for logging
13421283 lambda_value = li + (lf - li ) * current_step / total_steps
13431284 apparent_temp = t0 / lambda_value
13441285
1345- # Extract data from accumulator using get_results()
1346- distance = - 1.0
1347- mean_val = - 1.0
1348- std_val = - 1.0
1349- is_flagged = False # Whether this frame is flagged as anomalous
1350-
1351- try :
1352- frames_arr , vectors , distances_arr , flagged = (
1353- accumulator .get_results ()
1354- )
1355- # Get the latest distance (last element in distances array)
1356- if len (distances_arr ) > 0 :
1357- distance = distances_arr [- 1 ]
1358-
1359- # Check if current frame index is in the flagged array
1360- current_frame_idx = (
1361- len (frames_arr ) - 1 if len (frames_arr ) > 0 else - 1
1362- )
1363- is_flagged = current_frame_idx in flagged
1364-
1365- # Get detector statistics using common interface
1366- # All detectors now have get_statistics() method
1367- if hasattr (accumulator .detector , "get_statistics" ):
1368- # Stateful detectors (Adaptive, Cumulative, GPR) don't need arguments
1369- # Stateless detectors (StdDev, Threshold) take historical_distances
1370- try :
1371- stats = accumulator .detector .get_statistics ()
1372- except TypeError :
1373- # Stateless detector needs historical_distances argument
1374- stats = accumulator .detector .get_statistics (
1375- distances_arr
1376- )
1377-
1378- mean_val = (
1379- stats .get ("mean" , - 1.0 )
1380- if stats .get ("mean" ) is not None
1381- else - 1.0
1382- )
1383- std_val = (
1384- stats .get ("std" , - 1.0 )
1385- if stats .get ("std" ) is not None
1386- else - 1.0
1387- )
1388- except Exception as e :
1389- self .logger .debug (f"Could not extract detector statistics: { e } " )
1390-
13911286 # Write to monitoring file
13921287 monitor_out .write (
1393- f"{ block_idx } { current_step } { lambda_value :.6f} { apparent_temp :.2f} { int (is_flagged )} { distance :.6f} { mean_val :.6f} { std_val :.6f} \n "
1288+ f"{ block_idx } { current_step } { lambda_value :.6f} { apparent_temp :.2f} "
1289+ f"{ int (stats ['is_flagged' ])} { stats ['distance' ]:.6f} "
1290+ f"{ stats ['mean' ]:.6f} { stats ['std' ]:.6f} \n "
13941291 )
13951292 monitor_out .flush ()
13961293
1397- if is_flagged :
1294+ if stats [ " is_flagged" ] :
13981295 self .logger .warning (
13991296 f"⚠️ STRUCTURAL CHANGE DETECTED at step { current_step } (λ={ lambda_value :.4f} , T={ apparent_temp :.2f} K) "
14001297 f"during forward sweep (iteration { iteration } )"
@@ -1427,106 +1324,16 @@ def _reversible_scaling_forward(self, iteration=1):
14271324 self .logger .info (f"Structural monitoring data saved to { monitor_file } " )
14281325
14291326 # Generate monitoring plot
1430- try :
1431- import matplotlib
1432-
1433- matplotlib .use ("Agg" ) # Non-interactive backend
1434- import matplotlib .pyplot as plt
1435-
1436- # Read monitoring data
1437- data = np .loadtxt (monitor_file )
1438- if len (data .shape ) == 1 : # Single row
1439- data = data .reshape (1 , - 1 )
1440-
1441- block_idx_arr = data [:, 0 ]
1442- step_arr = data [:, 1 ]
1443- lambda_arr = data [:, 2 ]
1444- temp_arr = data [:, 3 ]
1445- is_flagged_arr = data [:, 4 ]
1446- distance_arr = data [:, 5 ]
1447-
1448- # Calculate cumulative mean and std
1449- cumulative_mean = np .array (
1450- [np .mean (distance_arr [: i + 1 ]) for i in range (len (distance_arr ))]
1451- )
1452- cumulative_std = np .array (
1453- [np .std (distance_arr [: i + 1 ]) for i in range (len (distance_arr ))]
1454- )
1455-
1456- # Create plot
1457- fig , ax = plt .subplots (figsize = (10 , 6 ))
1458- ax .plot (
1459- step_arr ,
1460- distance_arr ,
1461- "o-" ,
1462- label = "Descriptor Distance" ,
1463- linewidth = 2 ,
1464- markersize = 6 ,
1465- )
1466- ax .plot (
1467- step_arr ,
1468- cumulative_mean ,
1469- "k-" ,
1470- label = "Cumulative Mean" ,
1471- linewidth = 2 ,
1472- )
1473- ax .plot (
1474- step_arr ,
1475- cumulative_mean + 2.0 * cumulative_std ,
1476- "--" ,
1477- label = "Mean + 2σ" ,
1478- linewidth = 1.5 ,
1479- )
1480- ax .plot (
1481- step_arr ,
1482- cumulative_mean + 2.5 * cumulative_std ,
1483- "--" ,
1484- label = "Mean + 2.5σ" ,
1485- linewidth = 1.5 ,
1486- )
1487- ax .plot (
1488- step_arr ,
1489- cumulative_mean + 3.0 * cumulative_std ,
1490- "--" ,
1491- label = "Mean + 3σ" ,
1492- linewidth = 1.5 ,
1493- )
1494-
1495- # Highlight flagged points (anomalies)
1496- flagged_indices = np .where (is_flagged_arr > 0.5 )[0 ]
1497- if len (flagged_indices ) > 0 :
1498- ax .scatter (
1499- step_arr [flagged_indices ],
1500- distance_arr [flagged_indices ],
1501- c = "red" ,
1502- s = 200 ,
1503- marker = "X" ,
1504- edgecolors = "darkred" ,
1505- linewidth = 2 ,
1506- label = "Flagged Anomalies" ,
1507- zorder = 5 ,
1508- )
1509-
1510- ax .set_xlabel ("Step" , fontsize = 12 )
1511- ax .set_ylabel ("Descriptor Distance" , fontsize = 12 )
1512- ax .set_title (
1513- f"Structural Monitoring - Forward Sweep (Iteration { iteration } )" ,
1514- fontsize = 14 ,
1515- )
1516- ax .legend (loc = "best" , fontsize = 10 )
1517- ax .grid (True , alpha = 0.3 )
1518-
1519- # Save plot
1520- plot_file = os .path .join (
1521- self .simfolder , f"structural_monitoring_forward_{ iteration } .png"
1522- )
1523- plt .tight_layout ()
1524- plt .savefig (plot_file , dpi = 150 , bbox_inches = "tight" )
1525- plt .close ()
1327+ plot_file = os .path .join (
1328+ self .simfolder , f"structural_monitoring_forward_{ iteration } .png"
1329+ )
1330+ success = sm .generate_monitoring_plot (
1331+ monitor_file = monitor_file , output_file = plot_file , iteration = iteration
1332+ )
1333+ if success :
15261334 self .logger .info (f"Monitoring plot saved to { plot_file } " )
1527-
1528- except Exception as e :
1529- self .logger .warning (f"Failed to generate monitoring plot: { e } " )
1335+ else :
1336+ self .logger .warning ("Failed to generate monitoring plot" )
15301337
15311338 # Merge all cycle files into final output
15321339 self .logger .info (f"Merging cycle files into ts.forward_{ iteration } .dat" )
0 commit comments