Skip to content

Commit 04efdd2

Browse files
committed
Add a prototype of the heat flux wrapper
1 parent 173eb24 commit 04efdd2

File tree

2 files changed

+313
-9
lines changed

2 files changed

+313
-9
lines changed

python/metatomic_torch/metatomic/torch/ase_calculator.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@
9090
},
9191
}
9292

93+
IMPLEMENTED_PROPERTIES = [
94+
"heat_flux",
95+
]
96+
9397

9498
class MetatomicCalculator(ase.calculators.calculator.Calculator):
9599
"""
@@ -284,9 +288,9 @@ def __init__(
284288
for name, output in additional_outputs.items():
285289
assert isinstance(name, str)
286290
assert isinstance(output, torch.ScriptObject)
287-
assert "explicit_gradients_setter" in output._method_names(), (
288-
"outputs must be ModelOutput instances"
289-
)
291+
assert (
292+
"explicit_gradients_setter" in output._method_names()
293+
), "outputs must be ModelOutput instances"
290294

291295
self._additional_output_requests = additional_outputs
292296

@@ -309,7 +313,7 @@ def __init__(
309313

310314
# We do our own check to verify if a property is implemented in `calculate()`,
311315
# so we pretend to be able to compute all properties ASE knows about.
312-
self.implemented_properties = ALL_ASE_PROPERTIES
316+
self.implemented_properties = ALL_ASE_PROPERTIES + IMPLEMENTED_PROPERTIES
313317

314318
self.additional_outputs: Dict[str, TensorMap] = {}
315319
"""
@@ -1002,9 +1006,11 @@ def _get_ase_input(
10021006
[torch.full((values.shape[0],), 0), torch.arange(values.shape[0])]
10031007
).T,
10041008
),
1005-
components=[Labels(["xyz"], torch.arange(values.shape[1]).reshape(-1, 1))]
1006-
if values.shape[1] != 1
1007-
else [],
1009+
components=(
1010+
[Labels(["xyz"], torch.arange(values.shape[1]).reshape(-1, 1))]
1011+
if values.shape[1] != 1
1012+
else []
1013+
),
10081014
properties=Labels(
10091015
[
10101016
name if "::" not in name else name.split("::")[1],
@@ -1018,8 +1024,8 @@ def _get_ase_input(
10181024
)
10191025
tmap.set_info("quantity", option.quantity)
10201026
tmap.set_info("unit", option.unit)
1021-
tmap.to(dtype=dtype, device=device)
1022-
return tmap
1027+
1028+
return tmap.to(dtype=dtype, device=device)
10231029

10241030

10251031
def _ase_to_torch_data(atoms, dtype, device):
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
import torch
2+
3+
from torch.autograd.functional import jvp
4+
from typing import List, Dict, Optional
5+
from vesin.metatomic import compute_requested_neighbors
6+
7+
8+
from metatensor.torch import Labels, TensorBlock, TensorMap
9+
from metatomic.torch import (
10+
AtomisticModel,
11+
ModelEvaluationOptions,
12+
ModelOutput,
13+
System,
14+
)
15+
16+
17+
def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
    """Map Cartesian positions back inside the periodic cell.

    Positions are converted to fractional coordinates, reduced into the
    ``[0, 1)`` interval by removing the integer part, and converted back
    to Cartesian coordinates.
    """
    inv_cell = cell.inverse()
    # r @ inv_cell.T is the exact inverse of the back-transform below,
    # so unwrapped positions round-trip to themselves
    fractional = torch.matmul(positions, inv_cell.t())
    fractional = fractional - torch.floor(fractional)
    return torch.matmul(fractional, cell.t())
23+
24+
25+
def check_collisions(
    cell: torch.Tensor, positions: torch.Tensor, cutoff: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Flag atoms lying within ``cutoff`` of any face of the cell.

    Each position is projected onto the row-normalized rows of the inverse
    cell, and the projection is compared against the cell extent along the
    corresponding direction.

    Returns a pair ``(collisions, norm_coords)``: ``collisions`` is an
    ``[N, 6]`` boolean tensor with columns ordered
    ``(x_lo, x_hi, y_lo, y_hi, z_lo, z_hi)``; ``norm_coords`` holds the
    ``[N, 3]`` projected coordinates.
    """
    reciprocal = cell.inverse()
    # normalize every row of the inverse cell to unit length
    reciprocal = reciprocal / torch.linalg.norm(reciprocal, dim=1)[:, None]
    norm_coords = positions @ reciprocal.t()
    # extent of the cell along each of the normalized directions
    extents = torch.diagonal(cell @ reciprocal)
    near_low = norm_coords <= cutoff
    near_high = norm_coords >= extents - cutoff
    collisions = torch.hstack([near_low, near_high]).to(device=positions.device)

    # reorder columns so each axis' low/high flags sit next to each other
    return collisions[:, [0, 3, 1, 4, 2, 5]], norm_coords
38+
39+
40+
def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor:
    """
    Expand per-face collision flags into per-atom replica masks.

    collisions: [N, 6]: boolean flags ordered (x_lo, x_hi, y_lo, y_hi,
    z_lo, z_hi).

    Returns an [N, 3, 3, 3] boolean mask over image shifts; index 0 on
    each axis is "no shift", and the untranslated image itself is always
    False.
    """
    n_atoms = len(collisions)
    always = torch.ones(n_atoms, dtype=torch.bool, device=collisions.device)
    # per-axis candidates (origin, hi side, lo side), each of shape [3, N]
    along_x = torch.vstack([always, collisions[:, 0], collisions[:, 1]])
    along_y = torch.vstack([always, collisions[:, 2], collisions[:, 3]])
    along_z = torch.vstack([always, collisions[:, 4], collisions[:, 5]])
    # broadcast into the full [3, 3, 3, N] shift grid
    grid = along_x[:, None, None] & along_y[None, :, None] & along_z[None, None, :]
    # bring the atom dimension to the front: [N, 3, 3, 3]
    grid = torch.movedim(grid, -1, 0)
    # never replicate the original (unshifted) image
    grid[:, 0, 0, 0] = False
    return grid.to(device=collisions.device)
57+
58+
59+
def generate_replica_atoms(
    types: torch.Tensor,
    positions: torch.Tensor,
    cell: torch.Tensor,
    replicas: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Materialize the periodic-image atoms selected by ``replicas``.

    :param types: ``[N]`` atomic types
    :param positions: ``[N, 3]`` Cartesian positions
    :param cell: ``[3, 3]`` cell matrix
    :param replicas: ``[N, 3, 3, 3]`` boolean mask from
        ``collisions_to_replicas``; grid index 0/1/2 maps to a cell shift
        of 0/+1/-1 along that axis
    :return: ``(source_index, replica_types, replica_positions)``
    """
    selected = torch.argwhere(replicas)
    source_index = selected[:, 0]
    # translate grid indices (0, 1, 2) into cell shifts (0, +1, -1)
    shift_table = torch.tensor(
        [0, 1, -1], device=positions.device, dtype=positions.dtype
    )
    shifts = shift_table[selected[:, 1:]]
    # shifted copies of the source atoms: r_i + shifts combined with the cell
    replica_positions = positions[source_index] + shifts @ cell.t()

    return source_index, types[source_index], replica_positions
74+
75+
76+
def unfold_system(metatomic_system: System, cutoff: float) -> System:
    """Build a non-periodic "unfolded" copy of a periodic system.

    Atoms are first wrapped into the cell, then every atom closer than
    ``cutoff + 0.5`` to a cell face is replicated into the neighboring
    images, so that a model with interaction range ``cutoff`` can be
    evaluated on the result without periodic boundary conditions. The
    returned system has a zero cell and ``pbc=False`` on all axes, and
    carries "masses" and "velocities" extra data copied from each
    replica's source atom.
    """
    wrapped_positions = wrap_positions(
        metatomic_system.positions, metatomic_system.cell
    )
    # NOTE(review): the extra 0.5 looks like a safety margin on top of the
    # model's interaction range -- confirm the intended unit and value
    collisions, _ = check_collisions(
        metatomic_system.cell, wrapped_positions, cutoff + 0.5
    )
    replicas = collisions_to_replicas(collisions)
    replica_idx, replica_types, replica_positions = generate_replica_atoms(
        metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas
    )
    # original atoms come first; replicas are appended after them
    unfolded_types = torch.cat(
        [
            metatomic_system.types,
            replica_types,
        ]
    )
    unfolded_positions = torch.cat(
        [
            wrapped_positions,
            replica_positions,
        ]
    )
    # for every unfolded atom, the index of its source atom in the original
    # system (identity for the originals themselves)
    unfolded_idx = torch.cat(
        [
            torch.arange(len(metatomic_system.types), device=metatomic_system.device),
            replica_idx,
        ]
    )
    unfolded_n_atoms = len(unfolded_types)
    masses_block = metatomic_system.get_data("masses").block()
    velocities_block = metatomic_system.get_data("velocities").block()
    # replicas inherit the mass and velocity of their source atom
    unfolded_masses = masses_block.values[unfolded_idx]
    unfolded_velocities = velocities_block.values[unfolded_idx]
    unfolded_masses_block = TensorBlock(
        values=unfolded_masses,
        samples=Labels(
            ["atoms"],
            torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape(
                -1, 1
            ),
        ),
        components=masses_block.components,
        properties=masses_block.properties,
    )
    unfolded_velocities_block = TensorBlock(
        values=unfolded_velocities,
        samples=Labels(
            ["atoms"],
            torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape(
                -1, 1
            ),
        ),
        components=velocities_block.components,
        properties=velocities_block.properties,
    )
    # the unfolded system is treated as an isolated cluster: zero cell, no PBC
    unfolded_system = System(
        types=unfolded_types,
        positions=unfolded_positions,
        cell=torch.tensor(
            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            dtype=unfolded_positions.dtype,
            device=metatomic_system.device,
        ),
        pbc=torch.tensor([False, False, False], device=metatomic_system.device),
    )
    unfolded_system.add_data(
        "masses",
        TensorMap(
            Labels("_", torch.tensor([[0]], device=metatomic_system.device)),
            [unfolded_masses_block],
        ),
    )
    unfolded_system.add_data(
        "velocities",
        TensorMap(
            Labels("_", torch.tensor([[0]], device=metatomic_system.device)),
            [unfolded_velocities_block],
        ),
    )
    return unfolded_system.to(metatomic_system.dtype, metatomic_system.device)
157+
158+
159+
class HeatFluxWrapper(torch.nn.Module):
    """Expose an ``extra::heat_flux`` output on top of a wrapped model.

    The potential part of the flux is computed on an "unfolded"
    (non-periodic) copy of each system (see :func:`unfold_system`) by
    differentiating the per-atom energy barycenter and the total energy
    with respect to the unfolded positions; a convective part built from
    per-atom energies, masses, and velocities is added on top.
    """

    def __init__(self, model: AtomisticModel):
        super().__init__()

        self._model = model
        # TODO: throw error if the simulation cell is smaller than double the interaction range
        self._interaction_range = model.capabilities().interaction_range

        # extra per-atom data this wrapper needs attached to each System
        self._requested_inputs = {
            "masses": ModelOutput(quantity="mass", unit="u", per_atom=True),
            "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
        }

        # declare the new (per-system) heat flux output
        hf_output = ModelOutput(
            quantity="heat_flux",
            unit="",
            explicit_gradients=[],
            per_atom=False,
        )
        outputs = self._model.capabilities().outputs.copy()
        outputs["extra::heat_flux"] = hf_output
        # NOTE(review): this also mutates the wrapped model's capabilities
        # in place, in addition to the local copy above -- confirm intended
        self._model.capabilities().outputs["extra::heat_flux"] = hf_output

        # the flux computation needs per-atom energies from the model
        energies_output = ModelOutput(
            quantity="energy", unit=outputs["energy"].unit, per_atom=True
        )
        self._unfolded_run_options = ModelEvaluationOptions(
            length_unit=self._model.capabilities().length_unit,
            outputs={"energy": energies_output},
            selected_atoms=None,
        )

    def requested_inputs(self) -> Dict[str, ModelOutput]:
        # data ("masses", "velocities") the caller must provide on each System
        return self._requested_inputs

    def barycenter_and_atomic_energies(self, system: System, n_atoms: int):
        """Evaluate per-atom energies and the energy-weighted barycenter.

        Only the first ``n_atoms`` entries (the original, non-replica atoms)
        contribute to the barycenter and the total energy. Positions are
        detached for the barycenter weights so that the later gradient of
        ``barycenter`` flows through the energies only.
        """
        atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][
            0
        ].values.flatten()
        total_e = atomic_e[:n_atoms].sum()
        r_aux = system.positions.detach()
        barycenter = torch.einsum("i,ik->k", atomic_e[:n_atoms], r_aux[:n_atoms])

        return barycenter, atomic_e, total_e

    def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor:
        """Compute the heat flux vector for one system.

        The result is ``term1 - term2 + hf_conv`` where the first two terms
        come from autograd on the unfolded system and the last one is the
        convective contribution of the original atoms.
        """
        n_atoms = len(system.positions)
        # neighbor lists are computed on CPU, then moved back to the
        # system's device
        unfolded_system = unfold_system(system, self._interaction_range).to("cpu")
        compute_requested_neighbors(
            unfolded_system, self._unfolded_run_options.length_unit, model=self._model
        )
        unfolded_system = unfolded_system.to(system.device)
        velocities: torch.Tensor = (
            unfolded_system.get_data("velocities").block().values.reshape(-1, 3)
        )
        masses: torch.Tensor = (
            unfolded_system.get_data("masses").block().values.reshape(-1)
        )
        barycenter, atomic_e, total_e = self.barycenter_and_atomic_energies(
            unfolded_system, n_atoms
        )

        # term1[i] = sum over atoms of d(barycenter_i)/dr . v
        term1 = torch.zeros(
            (3), device=system.positions.device, dtype=system.positions.dtype
        )
        for i in range(3):
            # retain_graph: the same graph is reused for each component
            # and for the total-energy gradient below
            grad_i = torch.autograd.grad(
                [barycenter[i]],
                [unfolded_system.positions],
                retain_graph=True,
                create_graph=False,
            )[0]
            grad_i = torch.jit._unwrap_optional(grad_i)
            term1[i] = (grad_i * velocities).sum()

        # grad_outputs typed explicitly for TorchScript compatibility
        go = torch.jit.annotate(
            Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)]
        )
        grads = torch.autograd.grad(
            [total_e],
            [unfolded_system.positions],
            grad_outputs=go,
        )[0]
        grads = torch.jit._unwrap_optional(grads)
        # term2 = sum over atoms of r * (dE/dr . v)
        term2 = (
            unfolded_system.positions * (grads * velocities).sum(dim=1, keepdim=True)
        ).sum(dim=0)

        hf_pot = term1 - term2

        # convective part: (e_i + kinetic energy_i) * v_i, original atoms only
        hf_conv = (
            (
                atomic_e[:n_atoms]
                + 0.5
                * masses[:n_atoms]
                * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2
                * 103.6427  # u*A^2/fs^2 to eV
            )[:, None]
            * velocities[:n_atoms]
        ).sum(dim=0)

        return hf_pot + hf_conv

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels],
    ) -> Dict[str, TensorMap]:
        """Run the wrapped model, adding ``extra::heat_flux`` when requested.

        NOTE(review): ``selected_atoms`` is accepted but not forwarded to
        the wrapped model (``selected_atoms=None`` is passed) -- confirm
        this restriction is intentional for the prototype.
        """

        run_options = ModelEvaluationOptions(
            length_unit=self._model.capabilities().length_unit,
            outputs=outputs,
            selected_atoms=None,
        )
        results = self._model(systems, run_options, False)

        if "extra::heat_flux" not in outputs:
            return results

        device = systems[0].device
        heat_fluxes: List[torch.Tensor] = []
        for system in systems:
            heat_fluxes.append(self.calc_unfolded_heat_flux(system))

        samples = Labels(
            ["system"], torch.arange(len(systems), device=device).reshape(-1, 1)
        )

        # one 3-vector per system, stored with an "xyz" component axis
        hf_block = TensorBlock(
            values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device),
            samples=samples,
            components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))],
            properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)),
        )
        results["extra::heat_flux"] = TensorMap(
            Labels("_", torch.tensor([[0]], device=device)), [hf_block]
        )
        return results

0 commit comments

Comments
 (0)