1+ import torch
2+
3+ # --- Temporary fix for PyTorch 2.6 weights_only change ---
4+ from torch .serialization import add_safe_globals
5+ add_safe_globals ([slice ]) # allow 'slice' to be unpickled
6+ # ----------------------------------------------------------
7+
8+ from ase .io import read , write
9+ from fairchem .core import pretrained_mlip , FAIRChemCalculator
10+ import signal
11+ from time import time
12+ import numpy as np
13+ from ase .constraints import FixSymmetry
14+ from ase .filters import UnitCellFilter
15+ from ase .optimize .fire import FIRE
16+ import logging
17+ from pyxtal .optimize import WFS , DFS , QRS
18+ from pyxtal import pyxtal
19+ from pyxtal .util import get_pmg_dist
20+ import os
21+ from mace .calculators import mace_mp
22+ _cached_mace_mp = None
23+ from mace .calculators import mace_off
24+
25+
26+ def get_calculator (calculator ):
27+ global _cached_mace_mp
28+
29+ if type (calculator ) is str :
30+ if calculator == 'ANI' :
31+ import torchani
32+ calc = torchani .models .ANI2x ().ase ()
33+
34+ elif calculator == 'MACE' :
35+ if _cached_mace_mp is None :
36+ _cached_mace_mp = mace_mp (
37+ model = 'small' ,
38+ dispersion = True ,
39+ device = 'cpu'
40+ )
41+ calc = _cached_mace_mp
42+
43+ elif calculator == 'MACEOFF' :
44+ calc = mace_off (model = 'medium' , device = 'cpu' )
45+
46+ elif calculator == 'FAIRChem' :
47+ # Initialize FAIRChem
48+ predictor = pretrained_mlip .get_predict_unit ("uma-s-1p1" , device = "cpu" )
49+ calc = FAIRChemCalculator (predictor , task_name = "omc" )
50+
51+ else :
52+ raise ValueError (f"Unknown calculator: { calculator } " )
53+
54+ else :
55+ calc = calculator
56+
57+ return calc
58+
59+ class ASE_optimizer :
60+ """
61+ This is a ASE optimizer to perform oragnic crystal structure optimization.
62+ We assume that the geometry has been well optimized by classical FF.
63+
64+ Args:
65+ struc: pyxtal object
66+ calculator (str): 'ANI', 'MACE'
67+ opt_lat (bool): to opt lattice or not
68+ log_file (str): output file
69+ """
70+
71+ def __init__ (self , struc , calculator = "FAIRChem" , opt_lat = True , logfile = "ase_log_ebdc_local" ):
72+ self .structure = struc
73+ self .calculator = get_calculator (calculator )
74+ self .opt_lat = opt_lat
75+ self .stress = None
76+ self .forces = None
77+ self .optimized = True
78+ self .positions = None
79+ self .cell = None
80+ self .cputime = 0
81+ self .logfile = logfile
82+
83+ def run (self , fmax_target = 0.01 ):
84+ t0 = time ()
85+ s = self .structure .to_ase (resort = False )
86+ s .set_constraint (FixSymmetry (s ))
87+ s .set_calculator (self .calculator )
88+
89+ obj = UnitCellFilter (s ) if self .opt_lat else s
90+ dyn = FIRE (obj , a = 0.01 , logfile = self .logfile )
91+
92+ # <-- Key line: no 'steps' argument
93+ dyn .run (fmax = fmax_target )
94+
95+ if self .opt_lat :
96+ self .structure .lattice .set_matrix (s .get_cell ())
97+
98+ positions = s .get_scaled_positions ()
99+ count = 0
100+ for _i , site in enumerate (self .structure .mol_sites ):
101+ coords0 , _ = site ._get_coords_and_species (first = True )
102+ coords1 = positions [count : count + len (site .molecule .mol )]
103+ for j , coor in enumerate (coords1 ):
104+ diff = coor - coords0 [j ]
105+ diff -= np .round (diff )
106+ abs_diff = np .dot (diff , s .get_cell ())
107+ if abs (np .linalg .norm (abs_diff )) < 2.0 :
108+ coords1 [j ] = coords0 [j ] + diff
109+ else :
110+ print (coords1 [j ], coords1 [j ], np .linalg .norm (abs_diff ))
111+ import sys ; sys .exit ()
112+ site .update (coords1 , self .structure .lattice )
113+ count += len (site .molecule .mol ) * site .wp .multiplicity
114+
115+ self .structure .optimize_lattice ()
116+ self .structure .energy = s .get_potential_energy ()
117+ self .cell = s .get_cell ()
118+
119+ s .set_calculator ()
120+ s .set_constraint ()
121+ self .cputime = time () - t0
122+ self .optimized = bool (getattr (dyn , "converged" , False ))
123+
124+
125+ #("/Users/mmukta/Downloads/HOF-EBDC_aka_ZJU-HOF-60.cif", "O=C(O)c2cc(C#Cc1cc(C(=O)O)cc(C(=O)O)c1)cc(C(=O)O)c2")
126+ #("/Users/mmukta/Downloads/HOF-BDDC_aka_ZJU-HOF-62.cif", "O=C(O)c2cc(C#CC#Cc1cc(C(=O)O)cc(C(=O)O)c1)cc(C(=O)O)c2")
127+ #("/Users/mmukta/Downloads/HOF-1a.cif", "Nc8nc(N)nc(c7ccc(C(c2ccc(c1nc(N)nc(N)n1)cc2)(c4ccc(c3nc(N)nc(N)n3)cc4)c6ccc(c5nc(N)nc(N)n5)cc6)cc7)n8")
128+
129+ if __name__ == "__main__" :
130+ import os , warnings
131+ from pyxtal .db import database
132+ warnings .filterwarnings ("ignore" )
133+
134+ work_dir = "tmp"
135+ if not os .path .exists (work_dir ):
136+ os .makedirs (work_dir )
137+
138+ #db = database("pyxtal/database/test.db")
139+ #struc = db.get_pyxtal("ACSALA")
140+ data = [
141+ ("/Users/mmukta/Desktop/Cocrystal/ebdc-local-maceOff1500.cif" , "O=C(O)c2cc(C#Cc1cc(C(=O)O)cc(C(=O)O)c1)cc(C(=O)O)c2" )
142+ ]
143+
144+ for d in data :
145+ cif , smiles = d
146+ c = pyxtal (molecular = True )
147+ c .from_seed (cif , molecules = [smiles + '.smi' ])
148+ pmg0 = c .to_pymatgen ()
149+ if c .has_special_site ():
150+ c1 = c .to_subgroup (); print (c1 )
151+ pmg = c1 .to_pymatgen ()
152+ if get_pmg_dist (pmg0 , pmg ) > 0.1 :
153+ print ("The reference structure is not a valid subgroup." )
154+ m = c1 .mol_sites [0 ]
155+ m .rotate (ax_id = 2 , angle = 180 )
156+ pmg = c1 .to_pymatgen ()
157+ print ("Distance after flip" , get_pmg_dist (pmg0 , pmg ))
158+ c = c1
159+ pmg = c .to_pymatgen ()
160+ else :
161+ pmg = pmg0
162+ calc = ASE_optimizer (c )
163+ print (calc .structure .lattice )
164+ #calc.run(steps=1500)
165+ calc .run (fmax_target = 0.1 )
166+ print (calc .structure .energy )
167+ print (calc .structure .lattice )
168+ '''
169+ calc.structure.to_file("maceOff/ebdc-local.cif")
170+ from pymatgen.core import Structure
171+ # Load CIF
172+ structure = Structure.from_file("maceOff/ebdc-local.cif")
173+ # Get density in g/cm³
174+ print("Density:", structure.density, "g/cm³")
175+ total_molecules = 0
176+ print("Molecular sites and multiplicities:")
177+ for i, site in enumerate(c.mol_sites):
178+ print(f"Site {i+1}: Wyckoff multiplicity = {site.wp.multiplicity}")
179+ total_molecules += site.wp.multiplicity
180+
181+ print(f"\n Total molecules per unit cell: {total_molecules}")
182+ '''
0 commit comments