|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import gc |
| 4 | +import sys |
| 5 | +import warnings |
| 6 | +from typing import TYPE_CHECKING |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import paddle |
| 10 | +from chgnet.graph.crystalgraph import CrystalGraph |
| 11 | +from chgnet.graph.graph import Graph |
| 12 | +from chgnet.graph.graph import Node |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from typing import Literal |
| 16 | + |
| 17 | + from pymatgen.core import Structure |
| 18 | + from typing_extensions import Self |
# The import of the compiled graph builder is commented out in this module:
# try:
#     from chgnet.graph.cygraph import make_graph
# except (ImportError, AttributeError):
#     make_graph = None
# With make_graph fixed to None, CrystalGraphConverter.__init__ warns when
# algorithm="fast" is requested and falls back to the "legacy" implementation.
make_graph = None
# dtype string passed to paddle.to_tensor for the floating-point tensors
# (fractional coordinates, lattice, neighbor images) built by the converter.
DTYPE = "float32"
| 25 | + |
| 26 | + |
class CrystalGraphConverter(paddle.nn.Layer):
    """Convert a pymatgen.core.Structure to a CrystalGraph.

    The CrystalGraph dataclass stores essential fields to make sure that
    gradients like force and stress can be calculated through back-propagation later.
    """

    # Class-level placeholder. The methods below use the bare name `make_graph`,
    # which resolves to the module-level `make_graph`, not to this attribute.
    make_graph = None

    def __init__(
        self,
        *,
        atom_graph_cutoff: float = 6,
        bond_graph_cutoff: float | None = 3,
        algorithm: Literal["legacy", "fast"] = "fast",
        on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
        verbose: bool = False,
    ) -> None:
        """Initialize the Crystal Graph Converter.

        Args:
            atom_graph_cutoff (float): cutoff radius to search for neighboring atom in
                atom_graph. Default = 6.
            bond_graph_cutoff (float | None): bond length threshold to include bond in
                bond_graph. If None, atom_graph_cutoff is used instead.
                Default = 3.
            algorithm ('legacy' | 'fast'): algorithm to use for converting graphs.
                'legacy': python implementation of graph creation
                'fast': C implementation of graph creation, this is faster,
                    but will need the cygraph.c file correctly compiled from pip install
                If 'fast' is requested but the module-level `make_graph` is None,
                or if an unknown value is given, a UserWarning is emitted and
                'legacy' is used.
                Default = 'fast'
            on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
                with isolated atoms.
                Default = 'error'
            verbose (bool): whether to print the CrystalGraphConverter
                initialization message. Default = False.
        """
        super().__init__()
        self.atom_graph_cutoff = atom_graph_cutoff
        self.bond_graph_cutoff = (
            atom_graph_cutoff if bond_graph_cutoff is None else bond_graph_cutoff
        )
        self.on_isolated_atoms = on_isolated_atoms

        # Start from the pure-python implementation and only upgrade to the
        # compiled one when it was requested and is actually available.
        self.create_graph = self._create_graph_legacy
        self.algorithm = "legacy"
        if algorithm == "fast":
            if make_graph is not None:
                self.create_graph = self._create_graph_fast
                self.algorithm = "fast"
            else:
                warnings.warn(
                    "`fast` algorithm is not available, using `legacy`",
                    UserWarning,
                    stacklevel=1,
                )
        elif algorithm != "legacy":
            warnings.warn(
                f"Unknown algorithm={algorithm!r}, using `legacy`",
                UserWarning,
                stacklevel=1,
            )
        if verbose:
            print(self)

    def __repr__(self) -> str:
        """String representation of the CrystalGraphConverter."""
        atom_graph_cutoff = self.atom_graph_cutoff
        bond_graph_cutoff = self.bond_graph_cutoff
        algorithm = self.algorithm
        cls_name = type(self).__name__
        return f"{cls_name}(algorithm={algorithm!r}, atom_graph_cutoff={atom_graph_cutoff!r}, bond_graph_cutoff={bond_graph_cutoff!r})"

    def forward(
        self,
        structure: Structure,
        graph_id: str | None = None,
        mp_id: str | None = None,
    ) -> CrystalGraph:
        """Convert a structure, return a CrystalGraph.

        Args:
            structure (pymatgen.core.Structure): structure to convert
            graph_id (str): an id to keep track of this crystal graph
                Default = None
            mp_id (str): Materials Project id of this structure
                Default = None

        Return:
            CrystalGraph that is ready to use by CHGNet

        Raises:
            RuntimeError: if building the bond graph fails; the offending
                structure is written to ``bond_graph_error.cif`` first.
            ValueError: if the structure has isolated atoms and
                ``on_isolated_atoms`` is 'error'.
        """
        n_atoms = len(structure)
        data = [site.specie.Z for site in structure]
        # Atomic numbers are integer labels: no gradient is tracked for them.
        atomic_number = paddle.to_tensor(data, dtype="int32", stop_gradient=True)
        # Fractional coordinates and lattice are created with
        # stop_gradient=False so that gradients can be taken w.r.t. them.
        atom_frac_coord = paddle.to_tensor(
            data=structure.frac_coords, dtype=DTYPE, stop_gradient=False
        )
        lattice = paddle.to_tensor(
            data=structure.lattice.matrix, dtype=DTYPE, stop_gradient=False
        )
        center_index, neighbor_index, image, distance = structure.get_neighbor_list(
            r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-08
        )
        graph = self.create_graph(
            n_atoms, center_index, neighbor_index, image, distance
        )
        atom_graph, directed2undirected = graph.adjacency_list()
        atom_graph = paddle.to_tensor(data=atom_graph, dtype="int32")
        directed2undirected = paddle.to_tensor(data=directed2undirected, dtype="int32")
        try:
            bond_graph, undirected2directed = graph.line_graph_adjacency_list(
                cutoff=self.bond_graph_cutoff
            )
        except Exception as exc:
            # Dump the structure so the failure can be reproduced, then re-raise
            # with the graph id attached.
            structure.to(filename="bond_graph_error.cif")
            raise RuntimeError(
                f"Failed creating bond graph for {graph_id}, check bond_graph_error.cif"
            ) from exc
        bond_graph = paddle.to_tensor(data=bond_graph, dtype="int32")
        undirected2directed = paddle.to_tensor(data=undirected2directed, dtype="int32")

        # An atom is isolated when it never appears as the center of an edge.
        n_isolated_atoms = len({*range(n_atoms)} - {*center_index})
        if n_isolated_atoms:
            atom_graph_cutoff = self.atom_graph_cutoff
            msg = f"Structure graph_id={graph_id!r} has {n_isolated_atoms} isolated atom(s) with atom_graph_cutoff={atom_graph_cutoff!r}. CHGNet calculation will likely go wrong"
            if self.on_isolated_atoms == "error":
                raise ValueError(msg)
            elif self.on_isolated_atoms == "warn":
                print(msg, file=sys.stderr)
        return CrystalGraph(
            atomic_number=atomic_number,
            atom_frac_coord=atom_frac_coord,
            atom_graph=atom_graph,
            neighbor_image=paddle.to_tensor(data=image, dtype=DTYPE),
            directed2undirected=directed2undirected,
            undirected2directed=undirected2directed,
            bond_graph=bond_graph,
            lattice=lattice,
            graph_id=graph_id,
            mp_id=mp_id,
            composition=structure.composition.formula,
            atom_graph_cutoff=self.atom_graph_cutoff,
            bond_graph_cutoff=self.bond_graph_cutoff,
        )

    @staticmethod
    def _create_graph_legacy(
        n_atoms: int,
        center_index: np.ndarray,
        neighbor_index: np.ndarray,
        image: np.ndarray,
        distance: np.ndarray,
    ) -> Graph:
        """Given structure information, create a Graph structure to be used to
        create Crystal_Graph using pure python implementation.

        Args:
            n_atoms (int): the number of atoms in the structure
            center_index (np.ndarray): np array of indices of center atoms.
                [num_undirected_bonds]
            neighbor_index (np.ndarray): np array of indices of neighbor atoms.
                [num_undirected_bonds]
            image (np.ndarray): np array of images for each edge.
                [num_undirected_bonds, 3]
            distance (np.ndarray): np array of distances.
                [num_undirected_bonds]

        Return:
            Graph data structure used to create Crystal_Graph object
        """
        graph = Graph([Node(index=idx) for idx in range(n_atoms)])
        # strict=True makes a length mismatch between the four arrays an error
        # instead of silently truncating to the shortest one.
        for ii, jj, img, dist in zip(
            center_index, neighbor_index, image, distance, strict=True
        ):
            graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist)
        return graph

    @staticmethod
    def _create_graph_fast(
        n_atoms: int,
        center_index: np.ndarray,
        neighbor_index: np.ndarray,
        image: np.ndarray,
        distance: np.ndarray,
    ) -> Graph:
        """Given structure information, create a Graph structure to be used to
        create Crystal_Graph using C implementation.

        NOTE: this is the fast version of _create_graph_legacy optimized
            in c (~3x speedup).

        Args:
            n_atoms (int): the number of atoms in the structure
            center_index (np.ndarray): np array of indices of center atoms.
                [num_undirected_bonds]
            neighbor_index (np.ndarray): np array of indices of neighbor atoms.
                [num_undirected_bonds]
            image (np.ndarray): np array of images for each edge.
                [num_undirected_bonds, 3]
            distance (np.ndarray): np array of distances.
                [num_undirected_bonds]

        Return:
            Graph data structure used to create Crystal_Graph object
        """
        center_index = np.ascontiguousarray(center_index)
        neighbor_index = np.ascontiguousarray(neighbor_index)
        image = np.ascontiguousarray(image, dtype=np.int64)
        distance = np.ascontiguousarray(distance)

        # Setting threshold0 to 0 disables automatic garbage collection while
        # the graph objects are being created. The saved thresholds are
        # restored in `finally` so that a failure inside make_graph cannot
        # leave collection disabled for the rest of the process.
        gc_saved = gc.get_threshold()
        gc.set_threshold(0)
        try:
            nodes, dir_edges_list, undir_edges_list, undirected_edges = make_graph(
                center_index, len(center_index), neighbor_index, image, distance, n_atoms
            )
            graph = Graph(nodes=nodes)
            graph.directed_edges_list = dir_edges_list
            graph.undirected_edges_list = undir_edges_list
            graph.undirected_edges = undirected_edges
        finally:
            gc.set_threshold(*gc_saved)
        return graph

    def set_isolated_atom_response(
        self, on_isolated_atoms: Literal["ignore", "warn", "error"]
    ) -> None:
        """Set the graph converter's response to isolated atom graph.

        Args:
            on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
                with isolated atoms.

        Returns:
            None
        """
        self.on_isolated_atoms = on_isolated_atoms

    def as_dict(self) -> dict[str, str | float]:
        """Save the args of the graph converter.

        Only the cutoffs and the algorithm actually in use are stored;
        `on_isolated_atoms` is not part of the returned dict.
        """
        return {
            "atom_graph_cutoff": self.atom_graph_cutoff,
            "bond_graph_cutoff": self.bond_graph_cutoff,
            "algorithm": self.algorithm,
        }

    @classmethod
    def from_dict(cls, dct: dict) -> Self:
        """Create converter from dictionary.

        The dictionary entries are passed to the constructor as keyword arguments.
        """
        return cls(**dct)
0 commit comments