Skip to content

Commit 6652bb8

Browse files
Add CHGNet (#1053)
* Add CHGNet * 1228_ver_1_1 * add demo * fix: refine code style --------- Co-authored-by: zhangzhimin0830 <[email protected]>
1 parent ee2d77c commit 6652bb8

21 files changed

+5244
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from importlib.metadata import PackageNotFoundError
5+
from importlib.metadata import version
6+
from typing import Literal
7+
8+
try:
9+
__version__ = version(__name__)
10+
except PackageNotFoundError:
11+
__version__ = "unknown"
12+
TrainTask = Literal["ef", "efs", "efsm"]
13+
PredTask = Literal["e", "ef", "em", "efs", "efsm"]
14+
ROOT = os.path.dirname(os.path.dirname(__file__))

jointContribution/CHGNet/chgnet/data/__init__.py

Whitespace-only changes.

jointContribution/CHGNet/chgnet/data/dataset.py

Lines changed: 847 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from __future__ import annotations
2+
3+
from chgnet.graph.converter import CrystalGraphConverter # noqa
4+
from chgnet.graph.crystalgraph import CrystalGraph # noqa
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
from __future__ import annotations
2+
3+
import gc
4+
import sys
5+
import warnings
6+
from typing import TYPE_CHECKING
7+
8+
import numpy as np
9+
import paddle
10+
from chgnet.graph.crystalgraph import CrystalGraph
11+
from chgnet.graph.graph import Graph
12+
from chgnet.graph.graph import Node
13+
14+
if TYPE_CHECKING:
15+
from typing import Literal
16+
17+
from pymatgen.core import Structure
18+
from typing_extensions import Self
19+
# try:
20+
# from chgnet.graph.cygraph import make_graph
21+
# except (ImportError, AttributeError):
22+
# make_graph = None
23+
make_graph = None
24+
DTYPE = "float32"
25+
26+
27+
class CrystalGraphConverter(paddle.nn.Layer):
28+
"""Convert a pymatgen.core.Structure to a CrystalGraph
29+
The CrystalGraph dataclass stores essential field to make sure that
30+
gradients like force and stress can be calculated through back-propagation later.
31+
"""
32+
33+
make_graph = None
34+
35+
def __init__(
36+
self,
37+
*,
38+
atom_graph_cutoff: float = 6,
39+
bond_graph_cutoff: float = 3,
40+
algorithm: Literal["legacy", "fast"] = "fast",
41+
on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
42+
verbose: bool = False,
43+
) -> None:
44+
"""Initialize the Crystal Graph Converter.
45+
46+
Args:
47+
atom_graph_cutoff (float): cutoff radius to search for neighboring atom in
48+
atom_graph. Default = 5.
49+
bond_graph_cutoff (float): bond length threshold to include bond in
50+
bond_graph. Default = 3.
51+
algorithm ('legacy' | 'fast'): algorithm to use for converting graphs.
52+
'legacy': python implementation of graph creation
53+
'fast': C implementation of graph creation, this is faster,
54+
but will need the cygraph.c file correctly compiled from pip install
55+
Default = 'fast'
56+
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
57+
with isolated atoms.
58+
Default = 'error'
59+
verbose (bool): whether to print the CrystalGraphConverter
60+
initialization message. Default = False.
61+
"""
62+
super().__init__()
63+
self.atom_graph_cutoff = atom_graph_cutoff
64+
self.bond_graph_cutoff = (
65+
atom_graph_cutoff if bond_graph_cutoff is None else bond_graph_cutoff
66+
)
67+
self.on_isolated_atoms = on_isolated_atoms
68+
self.create_graph = self._create_graph_legacy
69+
self.algorithm = "legacy"
70+
if algorithm == "fast":
71+
if make_graph is not None:
72+
self.create_graph = self._create_graph_fast
73+
self.algorithm = "fast"
74+
else:
75+
warnings.warn(
76+
"`fast` algorithm is not available, using `legacy`",
77+
UserWarning,
78+
stacklevel=1,
79+
)
80+
elif algorithm != "legacy":
81+
warnings.warn(
82+
f"Unknown algorithm={algorithm!r}, using `legacy`",
83+
UserWarning,
84+
stacklevel=1,
85+
)
86+
if verbose:
87+
print(self)
88+
89+
def __repr__(self) -> str:
90+
"""String representation of the CrystalGraphConverter."""
91+
atom_graph_cutoff = self.atom_graph_cutoff
92+
bond_graph_cutoff = self.bond_graph_cutoff
93+
algorithm = self.algorithm
94+
cls_name = type(self).__name__
95+
return f"{cls_name}(algorithm={algorithm!r}, atom_graph_cutoff={atom_graph_cutoff!r}, bond_graph_cutoff={bond_graph_cutoff!r})"
96+
97+
def forward(self, structure: Structure, graph_id=None, mp_id=None) -> CrystalGraph:
98+
"""Convert a structure, return a CrystalGraph.
99+
100+
Args:
101+
structure (pymatgen.core.Structure): structure to convert
102+
graph_id (str): an id to keep track of this crystal graph
103+
Default = None
104+
mp_id (str): Materials Project id of this structure
105+
Default = None
106+
107+
Return:
108+
CrystalGraph that is ready to use by CHGNet
109+
"""
110+
n_atoms = len(structure)
111+
data = [site.specie.Z for site in structure]
112+
atomic_number = paddle.to_tensor(data, dtype="int32", stop_gradient=not False)
113+
atom_frac_coord = paddle.to_tensor(
114+
data=structure.frac_coords, dtype=DTYPE, stop_gradient=not True
115+
)
116+
lattice = paddle.to_tensor(
117+
data=structure.lattice.matrix, dtype=DTYPE, stop_gradient=not True
118+
)
119+
center_index, neighbor_index, image, distance = structure.get_neighbor_list(
120+
r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-08
121+
)
122+
graph = self.create_graph(
123+
n_atoms, center_index, neighbor_index, image, distance
124+
)
125+
atom_graph, directed2undirected = graph.adjacency_list()
126+
atom_graph = paddle.to_tensor(data=atom_graph, dtype="int32")
127+
directed2undirected = paddle.to_tensor(data=directed2undirected, dtype="int32")
128+
try:
129+
bond_graph, undirected2directed = graph.line_graph_adjacency_list(
130+
cutoff=self.bond_graph_cutoff
131+
)
132+
except Exception as exc:
133+
structure.to(filename="bond_graph_error.cif")
134+
raise RuntimeError(
135+
f"Failed creating bond graph for {graph_id}, check bond_graph_error.cif"
136+
) from exc
137+
bond_graph = paddle.to_tensor(data=bond_graph, dtype="int32")
138+
undirected2directed = paddle.to_tensor(data=undirected2directed, dtype="int32")
139+
n_isolated_atoms = len({*range(n_atoms)} - {*center_index})
140+
if n_isolated_atoms:
141+
atom_graph_cutoff = self.atom_graph_cutoff
142+
msg = f"Structure graph_id={graph_id!r} has {n_isolated_atoms} isolated atom(s) with atom_graph_cutoff={atom_graph_cutoff!r}. CHGNet calculation will likely go wrong"
143+
if self.on_isolated_atoms == "error":
144+
raise ValueError(msg)
145+
elif self.on_isolated_atoms == "warn":
146+
print(msg, file=sys.stderr)
147+
return CrystalGraph(
148+
atomic_number=atomic_number,
149+
atom_frac_coord=atom_frac_coord,
150+
atom_graph=atom_graph,
151+
neighbor_image=paddle.to_tensor(data=image, dtype=DTYPE),
152+
directed2undirected=directed2undirected,
153+
undirected2directed=undirected2directed,
154+
bond_graph=bond_graph,
155+
lattice=lattice,
156+
graph_id=graph_id,
157+
mp_id=mp_id,
158+
composition=structure.composition.formula,
159+
atom_graph_cutoff=self.atom_graph_cutoff,
160+
bond_graph_cutoff=self.bond_graph_cutoff,
161+
)
162+
163+
@staticmethod
164+
def _create_graph_legacy(
165+
n_atoms: int,
166+
center_index: np.ndarray,
167+
neighbor_index: np.ndarray,
168+
image: np.ndarray,
169+
distance: np.ndarray,
170+
) -> Graph:
171+
"""Given structure information, create a Graph structure to be used to
172+
create Crystal_Graph using pure python implementation.
173+
174+
Args:
175+
n_atoms (int): the number of atoms in the structure
176+
center_index (np.ndarray): np array of indices of center atoms.
177+
[num_undirected_bonds]
178+
neighbor_index (np.ndarray): np array of indices of neighbor atoms.
179+
[num_undirected_bonds]
180+
image (np.ndarray): np array of images for each edge.
181+
[num_undirected_bonds, 3]
182+
distance (np.ndarray): np array of distances.
183+
[num_undirected_bonds]
184+
185+
Return:
186+
Graph data structure used to create Crystal_Graph object
187+
"""
188+
graph = Graph([Node(index=idx) for idx in range(n_atoms)])
189+
for ii, jj, img, dist in zip(
190+
center_index, neighbor_index, image, distance, strict=True
191+
):
192+
graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist)
193+
return graph
194+
195+
@staticmethod
196+
def _create_graph_fast(
197+
n_atoms: int,
198+
center_index: np.ndarray,
199+
neighbor_index: np.ndarray,
200+
image: np.ndarray,
201+
distance: np.ndarray,
202+
) -> Graph:
203+
"""Given structure information, create a Graph structure to be used to
204+
create Crystal_Graph using C implementation.
205+
206+
NOTE: this is the fast version of _create_graph_legacy optimized
207+
in c (~3x speedup).
208+
209+
Args:
210+
n_atoms (int): the number of atoms in the structure
211+
center_index (np.ndarray): np array of indices of center atoms.
212+
[num_undirected_bonds]
213+
neighbor_index (np.ndarray): np array of indices of neighbor atoms.
214+
[num_undirected_bonds]
215+
image (np.ndarray): np array of images for each edge.
216+
[num_undirected_bonds, 3]
217+
distance (np.ndarray): np array of distances.
218+
[num_undirected_bonds]
219+
220+
Return:
221+
Graph data structure used to create Crystal_Graph object
222+
"""
223+
center_index = np.ascontiguousarray(center_index)
224+
neighbor_index = np.ascontiguousarray(neighbor_index)
225+
image = np.ascontiguousarray(image, dtype=np.int64)
226+
distance = np.ascontiguousarray(distance)
227+
gc_saved = gc.get_threshold()
228+
gc.set_threshold(0)
229+
nodes, dir_edges_list, undir_edges_list, undirected_edges = make_graph(
230+
center_index, len(center_index), neighbor_index, image, distance, n_atoms
231+
)
232+
graph = Graph(nodes=nodes)
233+
graph.directed_edges_list = dir_edges_list
234+
graph.undirected_edges_list = undir_edges_list
235+
graph.undirected_edges = undirected_edges
236+
gc.set_threshold(gc_saved[0])
237+
return graph
238+
239+
def set_isolated_atom_response(
240+
self, on_isolated_atoms: Literal["ignore", "warn", "error"]
241+
) -> None:
242+
"""Set the graph converter's response to isolated atom graph
243+
Args:
244+
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
245+
with isolated atoms.
246+
Default = 'error'.
247+
248+
Returns:
249+
None
250+
"""
251+
self.on_isolated_atoms = on_isolated_atoms
252+
253+
def as_dict(self) -> dict[str, str | float]:
254+
"""Save the args of the graph converter."""
255+
return {
256+
"atom_graph_cutoff": self.atom_graph_cutoff,
257+
"bond_graph_cutoff": self.bond_graph_cutoff,
258+
"algorithm": self.algorithm,
259+
}
260+
261+
@classmethod
262+
def from_dict(cls, dct: dict) -> Self:
263+
"""Create converter from dictionary."""
264+
return cls(**dct)

0 commit comments

Comments
 (0)