-
Notifications
You must be signed in to change notification settings - Fork 6
Add a heat flux wrapper #144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
GardevoirX
wants to merge
1
commit into
metatensor:main
Choose a base branch
from
GardevoirX:heat-flux-wrapper
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+313
−9
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,298 @@ | ||
| import torch | ||
|
|
||
| from torch.autograd.functional import jvp | ||
| from typing import List, Dict, Optional | ||
| from vesin.metatomic import compute_requested_neighbors | ||
|
|
||
|
|
||
| from metatensor.torch import Labels, TensorBlock, TensorMap | ||
| from metatomic.torch import ( | ||
| AtomisticModel, | ||
| ModelEvaluationOptions, | ||
| ModelOutput, | ||
| System, | ||
| ) | ||
|
|
||
|
|
||
| def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: | ||
| fractional_positions = torch.einsum("iv,kv->ik", positions, cell.inverse()) | ||
| fractional_positions -= torch.floor(fractional_positions) | ||
| wrapped_positions = torch.einsum("iv,kv->ik", fractional_positions, cell) | ||
|
|
||
| return wrapped_positions | ||
|
|
||
|
|
||
| def check_collisions( | ||
| cell: torch.Tensor, positions: torch.Tensor, cutoff: float | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| inv_cell = cell.inverse() | ||
| norm_inv_cell = torch.linalg.norm(inv_cell, dim=1) | ||
| inv_cell /= norm_inv_cell[:, None] | ||
| norm_coords = torch.einsum("iv,kv->ik", positions, inv_cell) | ||
| cell_vec_lengths = torch.diag(cell @ inv_cell) | ||
| collisions = torch.hstack( | ||
| [norm_coords <= cutoff, norm_coords >= cell_vec_lengths - cutoff], | ||
| ).to(device=positions.device) | ||
|
|
||
| return collisions[:, [0, 3, 1, 4, 2, 5]], norm_coords | ||
|
|
||
|
|
||
| def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Convert collisions to replicas. | ||
|
|
||
| collisions: [N, 6]: has collisions with (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) | ||
| """ | ||
| origin = torch.full( | ||
| (len(collisions),), True, dtype=torch.bool, device=collisions.device | ||
| ) | ||
| axs = torch.vstack([origin, collisions[:, 0], collisions[:, 1]]) | ||
| ays = torch.vstack([origin, collisions[:, 2], collisions[:, 3]]) | ||
| azs = torch.vstack([origin, collisions[:, 4], collisions[:, 5]]) | ||
| # leverage broadcasting | ||
| outs = axs[:, None, None] & ays[None, :, None] & azs[None, None, :] | ||
| outs = torch.movedim(outs, -1, 0) | ||
| outs[:, 0, 0, 0] = False | ||
| return outs.to(device=collisions.device) | ||
|
|
||
|
|
||
| def generate_replica_atoms( | ||
| types: torch.Tensor, | ||
| positions: torch.Tensor, | ||
| cell: torch.Tensor, | ||
| replicas: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| replicas = torch.argwhere(replicas) | ||
| replica_idx = replicas[:, 0] | ||
| replica_offsets = torch.tensor( | ||
| [0, 1, -1], device=positions.device, dtype=positions.dtype | ||
| )[replicas[:, 1:]] | ||
| replica_positions = positions[replica_idx] | ||
| replica_positions += torch.einsum("aA,iA->ia", cell, replica_offsets) | ||
|
|
||
| return replica_idx, types[replica_idx], replica_positions | ||
|
|
||
|
|
||
| def unfold_system(metatomic_system: System, cutoff: float) -> System: | ||
| wrapped_positions = wrap_positions( | ||
| metatomic_system.positions, metatomic_system.cell | ||
| ) | ||
| collisions, _ = check_collisions( | ||
| metatomic_system.cell, wrapped_positions, cutoff + 0.5 | ||
| ) | ||
| replicas = collisions_to_replicas(collisions) | ||
| replica_idx, replica_types, replica_positions = generate_replica_atoms( | ||
| metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas | ||
| ) | ||
| unfolded_types = torch.cat( | ||
| [ | ||
| metatomic_system.types, | ||
| replica_types, | ||
| ] | ||
| ) | ||
| unfolded_positions = torch.cat( | ||
| [ | ||
| wrapped_positions, | ||
| replica_positions, | ||
| ] | ||
| ) | ||
| unfolded_idx = torch.cat( | ||
| [ | ||
| torch.arange(len(metatomic_system.types), device=metatomic_system.device), | ||
| replica_idx, | ||
| ] | ||
| ) | ||
| unfolded_n_atoms = len(unfolded_types) | ||
| masses_block = metatomic_system.get_data("masses").block() | ||
| velocities_block = metatomic_system.get_data("velocities").block() | ||
| unfolded_masses = masses_block.values[unfolded_idx] | ||
| unfolded_velocities = velocities_block.values[unfolded_idx] | ||
| unfolded_masses_block = TensorBlock( | ||
| values=unfolded_masses, | ||
| samples=Labels( | ||
| ["atoms"], | ||
| torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape( | ||
| -1, 1 | ||
| ), | ||
| ), | ||
| components=masses_block.components, | ||
| properties=masses_block.properties, | ||
| ) | ||
| unfolded_velocities_block = TensorBlock( | ||
| values=unfolded_velocities, | ||
| samples=Labels( | ||
| ["atoms"], | ||
| torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape( | ||
| -1, 1 | ||
| ), | ||
| ), | ||
| components=velocities_block.components, | ||
| properties=velocities_block.properties, | ||
| ) | ||
| unfolded_system = System( | ||
| types=unfolded_types, | ||
| positions=unfolded_positions, | ||
| cell=torch.tensor( | ||
| [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], | ||
| dtype=unfolded_positions.dtype, | ||
| device=metatomic_system.device, | ||
| ), | ||
| pbc=torch.tensor([False, False, False], device=metatomic_system.device), | ||
| ) | ||
| unfolded_system.add_data( | ||
| "masses", | ||
| TensorMap( | ||
| Labels("_", torch.tensor([[0]], device=metatomic_system.device)), | ||
| [unfolded_masses_block], | ||
| ), | ||
| ) | ||
| unfolded_system.add_data( | ||
| "velocities", | ||
| TensorMap( | ||
| Labels("_", torch.tensor([[0]], device=metatomic_system.device)), | ||
| [unfolded_velocities_block], | ||
| ), | ||
| ) | ||
| return unfolded_system.to(metatomic_system.dtype, metatomic_system.device) | ||
|
|
||
|
|
||
| class HeatFluxWrapper(torch.nn.Module): | ||
|
|
||
| def __init__(self, model: AtomisticModel): | ||
| super().__init__() | ||
|
|
||
| self._model = model | ||
| # TODO: throw error if the simulation cell is smaller than double the interaction range | ||
| self._interaction_range = model.capabilities().interaction_range | ||
|
|
||
| self._requested_inputs = { | ||
| "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), | ||
| "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), | ||
| } | ||
|
|
||
| hf_output = ModelOutput( | ||
| quantity="heat_flux", | ||
| unit="", | ||
| explicit_gradients=[], | ||
| per_atom=False, | ||
| ) | ||
| outputs = self._model.capabilities().outputs.copy() | ||
| outputs["extra::heat_flux"] = hf_output | ||
| self._model.capabilities().outputs["extra::heat_flux"] = hf_output | ||
|
|
||
| energies_output = ModelOutput( | ||
| quantity="energy", unit=outputs["energy"].unit, per_atom=True | ||
| ) | ||
| self._unfolded_run_options = ModelEvaluationOptions( | ||
| length_unit=self._model.capabilities().length_unit, | ||
| outputs={"energy": energies_output}, | ||
| selected_atoms=None, | ||
| ) | ||
|
|
||
| def requested_inputs(self) -> Dict[str, ModelOutput]: | ||
| return self._requested_inputs | ||
|
|
||
| def barycenter_and_atomic_energies(self, system: System, n_atoms: int): | ||
| atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][ | ||
| 0 | ||
| ].values.flatten() | ||
| total_e = atomic_e[:n_atoms].sum() | ||
| r_aux = system.positions.detach() | ||
| barycenter = torch.einsum("i,ik->k", atomic_e[:n_atoms], r_aux[:n_atoms]) | ||
|
|
||
| return barycenter, atomic_e, total_e | ||
|
|
||
| def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: | ||
| n_atoms = len(system.positions) | ||
| unfolded_system = unfold_system(system, self._interaction_range).to("cpu") | ||
| compute_requested_neighbors( | ||
| unfolded_system, self._unfolded_run_options.length_unit, model=self._model | ||
| ) | ||
| unfolded_system = unfolded_system.to(system.device) | ||
| velocities: torch.Tensor = ( | ||
| unfolded_system.get_data("velocities").block().values.reshape(-1, 3) | ||
| ) | ||
| masses: torch.Tensor = ( | ||
| unfolded_system.get_data("masses").block().values.reshape(-1) | ||
| ) | ||
| barycenter, atomic_e, total_e = self.barycenter_and_atomic_energies( | ||
| unfolded_system, n_atoms | ||
| ) | ||
|
|
||
| term1 = torch.zeros( | ||
| (3), device=system.positions.device, dtype=system.positions.dtype | ||
| ) | ||
| for i in range(3): | ||
| grad_i = torch.autograd.grad( | ||
| [barycenter[i]], | ||
| [unfolded_system.positions], | ||
| retain_graph=True, | ||
| create_graph=False, | ||
| )[0] | ||
| grad_i = torch.jit._unwrap_optional(grad_i) | ||
| term1[i] = (grad_i * velocities).sum() | ||
|
|
||
| go = torch.jit.annotate( | ||
| Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)] | ||
| ) | ||
| grads = torch.autograd.grad( | ||
| [total_e], | ||
| [unfolded_system.positions], | ||
| grad_outputs=go, | ||
| )[0] | ||
| grads = torch.jit._unwrap_optional(grads) | ||
| term2 = ( | ||
| unfolded_system.positions * (grads * velocities).sum(dim=1, keepdim=True) | ||
| ).sum(dim=0) | ||
|
|
||
| hf_pot = term1 - term2 | ||
|
|
||
| hf_conv = ( | ||
| ( | ||
| atomic_e[:n_atoms] | ||
| + 0.5 | ||
| * masses[:n_atoms] | ||
| * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2 | ||
| * 103.6427 # u*A^2/fs^2 to eV | ||
| )[:, None] | ||
| * velocities[:n_atoms] | ||
| ).sum(dim=0) | ||
|
|
||
| return hf_pot + hf_conv | ||
|
|
||
| def forward( | ||
| self, | ||
| systems: List[System], | ||
| outputs: Dict[str, ModelOutput], | ||
| selected_atoms: Optional[Labels], | ||
| ) -> Dict[str, TensorMap]: | ||
|
|
||
| run_options = ModelEvaluationOptions( | ||
| length_unit=self._model.capabilities().length_unit, | ||
| outputs=outputs, | ||
| selected_atoms=None, | ||
| ) | ||
| results = self._model(systems, run_options, False) | ||
|
|
||
| if "extra::heat_flux" not in outputs: | ||
| return results | ||
|
|
||
| device = systems[0].device | ||
| heat_fluxes: List[torch.Tensor] = [] | ||
| for system in systems: | ||
| heat_fluxes.append(self.calc_unfolded_heat_flux(system)) | ||
|
|
||
| samples = Labels( | ||
| ["system"], torch.arange(len(systems), device=device).reshape(-1, 1) | ||
| ) | ||
|
|
||
| hf_block = TensorBlock( | ||
| values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device), | ||
| samples=samples, | ||
| components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))], | ||
| properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)), | ||
| ) | ||
| results["extra::heat_flux"] = TensorMap( | ||
| Labels("_", torch.tensor([[0]], device=device)), [hf_block] | ||
| ) | ||
| return results | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keep in mind that long-range models will have an infinite interaction range
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure I will put it in the documentation that this method only supports local/semi-local models