Skip to content

Commit 28b92d6

Browse files
committed
Energy Calculations Using Different Methods
1 parent 928ff76 commit 28b92d6

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

pyxtal/interface/ase_opt2.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import torch
2+
3+
# --- Temporary fix for PyTorch 2.6 weights_only change ---
4+
from torch.serialization import add_safe_globals
5+
add_safe_globals([slice]) # allow 'slice' to be unpickled
6+
# ----------------------------------------------------------
7+
8+
from ase.io import read, write
9+
from fairchem.core import pretrained_mlip, FAIRChemCalculator
10+
import signal
11+
from time import time
12+
import numpy as np
13+
from ase.constraints import FixSymmetry
14+
from ase.filters import UnitCellFilter
15+
from ase.optimize.fire import FIRE
16+
import logging
17+
from pyxtal.optimize import WFS, DFS, QRS
18+
from pyxtal import pyxtal
19+
from pyxtal.util import get_pmg_dist
20+
import os
21+
from mace.calculators import mace_mp
22+
_cached_mace_mp = None
23+
from mace.calculators import mace_off
24+
25+
26+
def get_calculator(calculator):
27+
global _cached_mace_mp
28+
29+
if type(calculator) is str:
30+
if calculator == 'ANI':
31+
import torchani
32+
calc = torchani.models.ANI2x().ase()
33+
34+
elif calculator == 'MACE':
35+
if _cached_mace_mp is None:
36+
_cached_mace_mp = mace_mp(
37+
model='small',
38+
dispersion=True,
39+
device='cpu'
40+
)
41+
calc = _cached_mace_mp
42+
43+
elif calculator == 'MACEOFF':
44+
calc = mace_off(model='medium', device='cpu')
45+
46+
elif calculator == 'FAIRChem':
47+
# Initialize FAIRChem
48+
predictor = pretrained_mlip.get_predict_unit("uma-s-1p1", device="cpu")
49+
calc = FAIRChemCalculator(predictor, task_name="omc")
50+
51+
else:
52+
raise ValueError(f"Unknown calculator: {calculator}")
53+
54+
else:
55+
calc = calculator
56+
57+
return calc
58+
59+
class ASE_optimizer:
60+
"""
61+
This is a ASE optimizer to perform oragnic crystal structure optimization.
62+
We assume that the geometry has been well optimized by classical FF.
63+
64+
Args:
65+
struc: pyxtal object
66+
calculator (str): 'ANI', 'MACE'
67+
opt_lat (bool): to opt lattice or not
68+
log_file (str): output file
69+
"""
70+
71+
def __init__(self, struc, calculator="FAIRChem", opt_lat=True, logfile="ase_log_ebdc_local"):
72+
self.structure = struc
73+
self.calculator = get_calculator(calculator)
74+
self.opt_lat = opt_lat
75+
self.stress = None
76+
self.forces = None
77+
self.optimized = True
78+
self.positions = None
79+
self.cell = None
80+
self.cputime = 0
81+
self.logfile = logfile
82+
83+
def run(self, fmax_target=0.01):
84+
t0 = time()
85+
s = self.structure.to_ase(resort=False)
86+
s.set_constraint(FixSymmetry(s))
87+
s.set_calculator(self.calculator)
88+
89+
obj = UnitCellFilter(s) if self.opt_lat else s
90+
dyn = FIRE(obj, a=0.01, logfile=self.logfile)
91+
92+
# <-- Key line: no 'steps' argument
93+
dyn.run(fmax=fmax_target)
94+
95+
if self.opt_lat:
96+
self.structure.lattice.set_matrix(s.get_cell())
97+
98+
positions = s.get_scaled_positions()
99+
count = 0
100+
for _i, site in enumerate(self.structure.mol_sites):
101+
coords0, _ = site._get_coords_and_species(first=True)
102+
coords1 = positions[count : count + len(site.molecule.mol)]
103+
for j, coor in enumerate(coords1):
104+
diff = coor - coords0[j]
105+
diff -= np.round(diff)
106+
abs_diff = np.dot(diff, s.get_cell())
107+
if abs(np.linalg.norm(abs_diff)) < 2.0:
108+
coords1[j] = coords0[j] + diff
109+
else:
110+
print(coords1[j], coords1[j], np.linalg.norm(abs_diff))
111+
import sys; sys.exit()
112+
site.update(coords1, self.structure.lattice)
113+
count += len(site.molecule.mol) * site.wp.multiplicity
114+
115+
self.structure.optimize_lattice()
116+
self.structure.energy = s.get_potential_energy()
117+
self.cell = s.get_cell()
118+
119+
s.set_calculator()
120+
s.set_constraint()
121+
self.cputime = time() - t0
122+
self.optimized = bool(getattr(dyn, "converged", False))
123+
124+
125+
#("/Users/mmukta/Downloads/HOF-EBDC_aka_ZJU-HOF-60.cif", "O=C(O)c2cc(C#Cc1cc(C(=O)O)cc(C(=O)O)c1)cc(C(=O)O)c2")
126+
#("/Users/mmukta/Downloads/HOF-BDDC_aka_ZJU-HOF-62.cif", "O=C(O)c2cc(C#CC#Cc1cc(C(=O)O)cc(C(=O)O)c1)cc(C(=O)O)c2")
127+
#("/Users/mmukta/Downloads/HOF-1a.cif", "Nc8nc(N)nc(c7ccc(C(c2ccc(c1nc(N)nc(N)n1)cc2)(c4ccc(c3nc(N)nc(N)n3)cc4)c6ccc(c5nc(N)nc(N)n5)cc6)cc7)n8")
128+
129+
if __name__ == "__main__":
130+
import os, warnings
131+
from pyxtal.db import database
132+
warnings.filterwarnings("ignore")
133+
134+
work_dir = "tmp"
135+
if not os.path.exists(work_dir):
136+
os.makedirs(work_dir)
137+
138+
#db = database("pyxtal/database/test.db")
139+
#struc = db.get_pyxtal("ACSALA")
140+
data = [
141+
("/Users/mmukta/Desktop/Cocrystal/ebdc-local-maceOff1500.cif", "O=C(O)c2cc(C#Cc1cc(C(=O)O)cc(C(=O)O)c1)cc(C(=O)O)c2")
142+
]
143+
144+
for d in data:
145+
cif, smiles = d
146+
c = pyxtal(molecular=True)
147+
c.from_seed(cif, molecules=[smiles+'.smi'])
148+
pmg0 = c.to_pymatgen()
149+
if c.has_special_site():
150+
c1 = c.to_subgroup(); print(c1)
151+
pmg = c1.to_pymatgen()
152+
if get_pmg_dist(pmg0, pmg) > 0.1:
153+
print("The reference structure is not a valid subgroup.")
154+
m = c1.mol_sites[0]
155+
m.rotate(ax_id=2, angle=180)
156+
pmg = c1.to_pymatgen()
157+
print("Distance after flip", get_pmg_dist(pmg0, pmg))
158+
c = c1
159+
pmg = c.to_pymatgen()
160+
else:
161+
pmg = pmg0
162+
calc = ASE_optimizer(c)
163+
print(calc.structure.lattice)
164+
#calc.run(steps=1500)
165+
calc.run(fmax_target=0.1)
166+
print(calc.structure.energy)
167+
print(calc.structure.lattice)
168+
'''
169+
calc.structure.to_file("maceOff/ebdc-local.cif")
170+
from pymatgen.core import Structure
171+
# Load CIF
172+
structure = Structure.from_file("maceOff/ebdc-local.cif")
173+
# Get density in g/cm³
174+
print("Density:", structure.density, "g/cm³")
175+
total_molecules = 0
176+
print("Molecular sites and multiplicities:")
177+
for i, site in enumerate(c.mol_sites):
178+
print(f"Site {i+1}: Wyckoff multiplicity = {site.wp.multiplicity}")
179+
total_molecules += site.wp.multiplicity
180+
181+
print(f"\nTotal molecules per unit cell: {total_molecules}")
182+
'''

0 commit comments

Comments
 (0)