@@ -96,6 +96,7 @@ def ase_atoms_to_atom_graphs(
9696 ),
9797 system_id : Optional [int ] = None ,
9898 brute_force_knn : Optional [bool ] = None ,
99+ device : Optional [torch .device ] = torch .device ("cuda" if torch .cuda .is_available () else "cpu" ),
99100) -> AtomGraphs :
100101 """Generate AtomGraphs from an ase.Atoms object.
101102
@@ -107,6 +108,7 @@ def ase_atoms_to_atom_graphs(
107108 Defaults to None, in which case brute_force is used if we a GPU is avaiable (2-6x faster),
108109 but not on CPU (1.5x faster - 4x slower). For very large systems, brute_force may OOM on GPU,
109110 so it is recommended to set to False in that case.
111+ device: device to put the tensors on.
110112
111113 Returns:
112114 AtomGraphs object
@@ -133,7 +135,7 @@ def ase_atoms_to_atom_graphs(
133135 )
134136
135137 num_atoms = len (node_feats ["positions" ]) # type: ignore
136- return AtomGraphs (
138+ atom_graph = AtomGraphs (
137139 senders = senders ,
138140 receivers = receivers ,
139141 n_node = torch .tensor ([num_atoms ]),
@@ -147,6 +149,7 @@ def ase_atoms_to_atom_graphs(
147149 radius = system_config .radius ,
148150 max_num_neighbors = system_config .max_num_neighbors ,
149151 )
152+ return atom_graph .to (device )
150153
151154
152155def _get_edge_feats (
0 commit comments