11import itertools
2- import json
32from collections .abc import Sequence
43from functools import cache
54from typing import Any
1211from torch .utils .data import Dataset
1312from tqdm import tqdm
1413
15- from aviary import PKG_DIR
16-
1714
1815class CrystalGraphData (Dataset ):
1916 """Dataset class for the CGCNN structure model."""
@@ -22,58 +19,32 @@ def __init__(
2219 self ,
2320 df : pd .DataFrame ,
2421 task_dict : dict [str , str ],
25- elem_embedding : str = "cgcnn92" ,
2622 structure_col : str = "structure" ,
2723 identifiers : Sequence [str ] = (),
28- radius : float = 5 ,
24+ radius_cutoff : float = 5 ,
2925 max_num_nbr : int = 12 ,
30- dmin : float = 0 ,
31- step : float = 0.2 ,
3226 ):
3327 """Featurize crystal structures into neighborhood graphs with this data class
3428 for CGCNN.
3529
3630 Args:
3731 df (pd.Dataframe): Pandas dataframe holding input and target values.
3832 task_dict ({target: task}): task dict for multi-task learning
39- elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16,
40- onehot112 or path to a file with custom element embeddings.
41- Defaults to matscholar200.
4233 structure_col (str, optional): df column holding pymatgen Structure objects
4334 as input.
4435 identifiers (list[str], optional): df columns for distinguishing data
4536 points. Will be copied over into the model's output CSV. Defaults to ().
46- radius (float, optional): Cut-off radius for neighborhood. Defaults to 5.
37+ radius_cutoff (float, optional): Cut-off radius for neighborhood.
38+ Defaults to 5.
4739 max_num_nbr (int, optional): maximum number of neighbors to consider.
4840 Defaults to 12.
49- dmin (float, optional): minimum distance in Gaussian basis. Defaults to 0.
50- step (float, optional): increment size of Gaussian basis. Defaults to 0.2.
5141 """
5242 self .task_dict = task_dict
5343 self .identifiers = list (identifiers )
5444
55- self .radius = radius
45+ self .radius_cutoff = radius_cutoff
5646 self .max_num_nbr = max_num_nbr
5747
58- if elem_embedding in ("matscholar200" , "cgcnn92" , "megnet16" , "onehot112" ):
59- elem_embedding = f"{ PKG_DIR } /embeddings/element/{ elem_embedding } .json"
60-
61- with open (elem_embedding ) as file :
62- self .elem_features = json .load (file )
63-
64- for key , value in self .elem_features .items ():
65- self .elem_features [key ] = np .array (value , dtype = float )
66- if not hasattr (self , "elem_emb_len" ):
67- self .elem_emb_len = len (value )
68- elif self .elem_emb_len != len (value ):
69- raise ValueError (
70- f"Element embedding length mismatch: len({ key } )="
71- f"{ len (value )} , expected { self .elem_emb_len } "
72- )
73-
74- self .gaussian_dist_func = GaussianDistance (dmin = dmin , dmax = radius , step = step )
75- self .nbr_fea_dim = self .gaussian_dist_func .embedding_size
76-
7748 self .df = df
7849 self .structure_col = structure_col
7950
@@ -84,7 +55,7 @@ def __init__(
8455 self .df [structure_col ].items (), total = len (df ), desc = desc , disable = None
8556 ):
8657 self_idx , nbr_idx , _ = get_structure_neighbor_info (
87- struct , radius , max_num_nbr
58+ struct , self . radius_cutoff , self . max_num_nbr
8859 )
8960 material_ids = [idx , * self .df .loc [idx ][self .identifiers ]]
9061 if 0 in (len (self_idx ), len (nbr_idx )):
@@ -140,16 +111,10 @@ def __getitem__(self, idx: int):
140111 material_ids = [self .df .index [idx ], * row [self .identifiers ]]
141112
142113 # atom features for disordered sites
143- site_atoms = [atom .species .as_dict () for atom in struct ]
144- atom_features = np .vstack (
145- [
146- np .sum ([self .elem_features [el ] * amt for el , amt in site .items ()], axis = 0 )
147- for site in site_atoms
148- ]
149- )
114+ atom_features = [atom .specie .Z for atom in struct ]
150115
151116 self_idx , nbr_idx , nbr_dist = get_structure_neighbor_info (
152- struct , self .radius , self .max_num_nbr
117+ struct , self .radius_cutoff , self .max_num_nbr
153118 )
154119
155120 if len (self_idx ) == 0 :
@@ -161,9 +126,7 @@ def __getitem__(self, idx: int):
161126 if set (self_idx ) != set (range (len (struct ))):
162127 raise ValueError (f"At least one atom in { material_ids } is isolated" )
163128
164- nbr_dist = self .gaussian_dist_func .expand (nbr_dist )
165-
166- atom_fea_t = Tensor (atom_features )
129+ atom_fea_t = LongTensor (atom_features )
167130 nbr_dist_t = Tensor (nbr_dist )
168131 self_idx_t = LongTensor (self_idx )
169132 nbr_idx_t = LongTensor (nbr_idx )
@@ -278,27 +241,25 @@ def __init__(
278241 "Max radii below minimum radii + step size - please increase dmax."
279242 )
280243
281- self .filter = np .arange (dmin , dmax + step , step )
244+ self .filter = torch .arange (dmin , dmax + step , step )
282245 self .embedding_size = len (self .filter )
283246
284247 if var is None :
285248 var = step
286249
287250 self .var = var
288251
289- def expand (self , distances : np . ndarray ) -> np . ndarray :
252+ def expand (self , distances : Tensor ) -> Tensor :
290253 """Apply Gaussian distance filter to a numpy distance array.
291254
292255 Args:
293256 distances (ArrayLike): A distance matrix of any shape.
294257
295258 Returns:
296- np.ndarray : Expanded distance matrix with the last dimension of length
259+ Tensor : Expanded distance matrix with the last dimension of length
297260 len(self.filter)
298261 """
299- distances = np .array (distances )
300-
301- return np .exp (- ((distances [..., None ] - self .filter ) ** 2 ) / self .var ** 2 )
262+ return torch .exp (- ((distances [..., None ] - self .filter ) ** 2 ) / self .var ** 2 )
302263
303264
304265def get_structure_neighbor_info (
0 commit comments