
Commit 3dda11b

Add stress support, update precision defaults, update README (#52)
2 parents: b3680ef + f066ab8

File tree: 6 files changed (+89 additions, −103 deletions)


.github/workflows/tests.yml

Lines changed: 3 additions & 11 deletions
```diff
@@ -1,14 +1,6 @@
 name: Run LAMMPS-Python tests
 
-on:
-  push:
-    branches:
-      - main
-      - develop
-
-  pull_request:
-    branches:
-      - main
+on: [push, pull_request]
 
 jobs:
   build:
@@ -17,7 +9,7 @@ jobs:
     strategy:
       matrix:
        python-version: [3.9]
-        torch-version: [1.10.1, 1.11.0]
+        torch-version: [1.11.0]
        nequip-branch: ["main"]
 
     steps:
@@ -46,7 +38,7 @@ jobs:
        run: |
          mkdir lammps_dir/
          cd lammps_dir/
-          git clone -b stable_29Sep2021_update2 --depth 1 "https://github.com/lammps/lammps"
+          git clone --depth 1 "https://github.com/lammps/lammps"
          cd ..
          ./patch_lammps.sh lammps_dir/lammps/
          cd lammps_dir/lammps/
```

README.md

Lines changed: 10 additions & 8 deletions
```diff
@@ -8,7 +8,7 @@ This pair style allows you to use NequIP models from the [`nequip`](https://gith
 
 ## Pre-requisites
 
-* PyTorch or LibTorch >= 1.10.0
+* PyTorch or LibTorch >= 1.11.0; please note that at present we have only thoroughly tested 1.11 on NVIDIA GPUs (see [#311 for NequIP](https://github.com/mir-group/nequip/discussions/311#discussioncomment-5129513)) and 1.13 on AMD GPUs, but newer 2.x versions *may* also work. With newer versions, setting the environment variable `PYTORCH_JIT_USE_NNC_NOT_NVFUSER=1` sometimes helps.
 
 ## Usage in LAMMPS
 
```
````diff
@@ -18,21 +18,21 @@ pair_coeff * * deployed.pth <type name 1> <type name 2> ...
 ```
 where `deployed.pth` is the filename of your trained, **deployed** model.
 
-The names after the model path `deployed.pth` indicate, in order, the names of the NequIP model's atom types to use for LAMMPS atom types 1, 2, and so on. The number of names given must be equal to the number of atom types in the LAMMPS configuration (not the NequIP model!).
+The names after the model path `deployed.pth` indicate, in order, the names of the NequIP model's atom types to use for LAMMPS atom types 1, 2, and so on. The number of names given must be equal to the number of atom types in the LAMMPS configuration (not the NequIP model!).
 The given names must be consistent with the names specified in the NequIP training YAML in `chemical_symbol_to_type` or `type_names`.
 
 ## Building LAMMPS with this pair style
 
 ### Download LAMMPS
 ```bash
-git clone -b stable_29Sep2021_update2 --depth 1 git@github.com:lammps/lammps
+git clone --depth=1 https://github.com/lammps/lammps
 ```
 or your preferred method.
-(`--depth 1` prevents the entire history of the LAMMPS repository from being downloaded.)
+(`--depth=1` prevents the entire history of the LAMMPS repository from being downloaded.)
 
 ### Download this repository
 ```bash
-git clone git@github.com:mir-group/pair_nequip
+git clone https://github.com/mir-group/pair_nequip
 ```
 
 ### Patch LAMMPS
````
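As a sketch of the type mapping described in the README hunk above: each LAMMPS atom type 1..N is paired, in order, with one NequIP type name, and the count must match the number of LAMMPS atom types. The helper below is a hypothetical illustration, not part of `pair_nequip`:

```python
# Hypothetical illustration (not part of pair_nequip): the correspondence a
# line like `pair_coeff * * deployed.pth H O` establishes for a two-type system.
def map_lammps_types(n_lammps_types, type_names):
    """Map 1-based LAMMPS atom types to NequIP type names, in order."""
    if len(type_names) != n_lammps_types:
        raise ValueError("need exactly one NequIP type name per LAMMPS atom type")
    return {i + 1: name for i, name in enumerate(type_names)}

print(map_lammps_types(2, ["H", "O"]))  # {1: 'H', 2: 'O'}
```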
````diff
@@ -49,7 +49,6 @@ cp /path/to/pair_nequip/*.cpp /path/to/lammps/src/
 cp /path/to/pair_nequip/*.h /path/to/lammps/src/
 ```
 Then make the following modifications to `lammps/cmake/CMakeLists.txt`:
-- Change `set(CMAKE_CXX_STANDARD 11)` to `set(CMAKE_CXX_STANDARD 14)`
 - Append the following lines:
 ```cmake
 find_package(Torch REQUIRED)
````
````diff
@@ -106,6 +105,9 @@ This gives `lammps/build/lmp`, which can be run as usual with `/path/to/lmp -in 
 ```
 
    A: Make sure you remembered to deploy (compile) your model using `nequip-deploy`, and that the path to the model given with `pair_coeff` points to a deployed model `.pth` file, **not** a file containing only weights like `best_model.pth`.
-3. Q: The output pressures and stresses seem wrong / my NPT simulation is broken
+3. Q: I get the following error:
+   ```
+   Exception: Argument passed to at() was not in the map
+   ```
 
-   A: NPT/stress support in LAMMPS for `pair_nequip` is in-progress on the `stress` branch and is not yet finished.
+   A: We now require models to have been trained with stress support, which is achieved by replacing `ForceOutput` with `StressForceOutput` in the training configuration. Note that you do not need to train on stress (though it may improve your potential, assuming your stress data is correct and converged). If you desperately wish to keep using a model without stress output, there are two options: 1) Remove lines that look like [these](https://github.com/mir-group/pair_allegro/blob/99036043e74376ac52993b5323f193dee3f4f401/pair_allegro_kokkos.cpp#L332-L343) in your version of `pair_allegro[_kokkos].cpp` 2) Redeploy the model with an updated config file, as described [here](https://github.com/mir-group/nequip/issues/69#issuecomment-1129273665).
````

pair_nequip.cpp

Lines changed: 39 additions & 81 deletions
```diff
@@ -38,28 +38,12 @@
 #include <torch/torch.h>
 #include <torch/script.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
-//#include <c10/cuda/CUDACachingAllocator.h>
-
-
-// We have to do a backward compatability hack for <1.10
-// https://discuss.pytorch.org/t/how-to-check-libtorch-version/77709/4
-// Basically, the check in torch::jit::freeze
-// (see https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp#L479)
-// is wrong, and we have ro "reimplement" the function
-// to get around that...
-// it's broken in 1.8 and 1.9
-// BUT the internal logic in the function is wrong in 1.10
-// So we only use torch::jit::freeze in >=1.11
+
+// Freezing is broken from C++ in <=1.10; so we've dropped support.
 #if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10)
-#define DO_TORCH_FREEZE_HACK
-// For the hack, need more headers:
-#include <torch/csrc/jit/passes/freeze_module.h>
-#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
-#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
-#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
+#error "PyTorch version < 1.11 is not supported"
 #endif
 
-
 using namespace LAMMPS_NS;
 
 PairNEQUIP::PairNEQUIP(LAMMPS *lmp) : Pair(lmp) {
```
```diff
@@ -92,13 +76,7 @@ void PairNEQUIP::init_style(){
   if (atom->tag_enable == 0)
     error->all(FLERR,"Pair style NEQUIP requires atom IDs");
 
-  // need a full neighbor list
-  int irequest = neighbor->request(this,instance_me);
-  neighbor->requests[irequest]->half = 0;
-  neighbor->requests[irequest]->full = 1;
-
-  // TODO: probably also
-  neighbor->requests[irequest]->ghost = 0;
+  neighbor->add_request(this, NeighConst::REQ_FULL);
 
   // TODO: I think Newton should be off, enforce this.
   // The network should just directly compute the total forces
```
```diff
@@ -125,7 +103,7 @@ void PairNEQUIP::allocate()
 }
 
 void PairNEQUIP::settings(int narg, char ** /*arg*/) {
-  // "flare" should be the only word after "pair_style" in the input file.
+  // "nequip" should be the only word after "pair_style" in the input file.
   if (narg > 0)
     error->all(FLERR, "Illegal pair_style command");
 }
```
```diff
@@ -186,52 +164,23 @@ void PairNEQUIP::coeff(int narg, char **arg) {
   // This is the check used by PyTorch: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp#L476
   if (model.hasattr("training")) {
     std::cout << "Freezing TorchScript model...\n";
-#ifdef DO_TORCH_FREEZE_HACK
-    // Do the hack
-    // Copied from the implementation of torch::jit::freeze,
-    // except without the broken check
-    // See https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp
-    bool optimize_numerics = true; // the default
-    // the {} is preserved_attrs
-    auto out_mod = freeze_module(
-      model, {}
-    );
-    // See 1.11 bugfix in https://github.com/pytorch/pytorch/pull/71436
-    auto graph = out_mod.get_method("forward").graph();
-    OptimizeFrozenGraph(graph, optimize_numerics);
-    model = out_mod;
-#else
-    // Do it normally
-    model = torch::jit::freeze(model);
-#endif
+    model = torch::jit::freeze(model);
   }
 
-#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10)
-  // Set JIT bailout to avoid long recompilations for many steps
-  size_t jit_bailout_depth;
-  if (metadata["_jit_bailout_depth"].empty()) {
-    // This is the default used in the Python code
-    jit_bailout_depth = 2;
-  } else {
-    jit_bailout_depth = std::stoi(metadata["_jit_bailout_depth"]);
-  }
-  torch::jit::getBailoutDepth() = jit_bailout_depth;
-#else
-  // In PyTorch >=1.11, this is now set_fusion_strategy
-  torch::jit::FusionStrategy strategy;
-  if (metadata["_jit_fusion_strategy"].empty()) {
-    // This is the default used in the Python code
-    strategy = {{torch::jit::FusionBehavior::DYNAMIC, 3}};
-  } else {
-    std::stringstream strat_stream(metadata["_jit_fusion_strategy"]);
-    std::string fusion_type, fusion_depth;
-    while(std::getline(strat_stream, fusion_type, ',')) {
-      std::getline(strat_stream, fusion_depth, ';');
-      strategy.push_back({fusion_type == "STATIC" ? torch::jit::FusionBehavior::STATIC : torch::jit::FusionBehavior::DYNAMIC, std::stoi(fusion_depth)});
-    }
+  // In PyTorch >=1.11, this is now set_fusion_strategy
+  torch::jit::FusionStrategy strategy;
+  if (metadata["_jit_fusion_strategy"].empty()) {
+    // This is the default used in the Python code
+    strategy = {{torch::jit::FusionBehavior::DYNAMIC, 3}};
+  } else {
+    std::stringstream strat_stream(metadata["_jit_fusion_strategy"]);
+    std::string fusion_type, fusion_depth;
+    while(std::getline(strat_stream, fusion_type, ',')) {
+      std::getline(strat_stream, fusion_depth, ';');
+      strategy.push_back({fusion_type == "STATIC" ? torch::jit::FusionBehavior::STATIC : torch::jit::FusionBehavior::DYNAMIC, std::stoi(fusion_depth)});
   }
-  torch::jit::setFusionStrategy(strategy);
-#endif
+  }
+  torch::jit::setFusionStrategy(strategy);
 
   // Set whether to allow TF32:
   bool allow_tf32;
```
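The `_jit_fusion_strategy` metadata string consumed by the `std::getline` loop above has the form `TYPE,depth;TYPE,depth;…`, e.g. `STATIC,2;DYNAMIC,10`. A minimal Python sketch of the same parsing logic (the function name is hypothetical, for illustration only):

```python
# Sketch (assumption: illustrative only) of parsing a `_jit_fusion_strategy`
# metadata string like "STATIC,2;DYNAMIC,10" into (behavior, depth) pairs,
# mirroring the C++ std::getline loop in coeff() above.
def parse_fusion_strategy(metadata_value):
    if not metadata_value:
        # Default used when the deployed model carries no strategy metadata
        return [("DYNAMIC", 3)]
    strategy = []
    for entry in metadata_value.split(";"):
        if not entry:
            continue
        fusion_type, fusion_depth = entry.split(",")
        strategy.append((fusion_type, int(fusion_depth)))
    return strategy

print(parse_fusion_strategy("STATIC,2;DYNAMIC,10"))  # [('STATIC', 2), ('DYNAMIC', 10)]
```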
```diff
@@ -463,40 +412,49 @@ void PairNEQUIP::compute(int eflag, int vflag){
   auto output = model.forward(input_vector).toGenericDict();
 
   torch::Tensor forces_tensor = output.at("forces").toTensor().cpu();
-  auto forces = forces_tensor.accessor<float, 2>();
+  auto forces = forces_tensor.accessor<double, 2>();
 
   torch::Tensor total_energy_tensor = output.at("total_energy").toTensor().cpu();
 
   // store the total energy where LAMMPS wants it
-  eng_vdwl = total_energy_tensor.data_ptr<float>()[0];
+  eng_vdwl = total_energy_tensor.data_ptr<double>()[0];
 
   torch::Tensor atomic_energy_tensor = output.at("atomic_energy").toTensor().cpu();
-  auto atomic_energies = atomic_energy_tensor.accessor<float, 2>();
-  float atomic_energy_sum = atomic_energy_tensor.sum().data_ptr<float>()[0];
+  auto atomic_energies = atomic_energy_tensor.accessor<double, 2>();
+
+  if(vflag){
+    torch::Tensor v_tensor = output.at("virial").toTensor().cpu();
+    auto v = v_tensor.accessor<double, 3>();
+    // Convert from 3x3 symmetric tensor format, which NequIP outputs, to the flattened form LAMMPS expects
+    // First [0] index on v is batch
+    virial[0] = v[0][0][0];
+    virial[1] = v[0][1][1];
+    virial[2] = v[0][2][2];
+    virial[3] = v[0][0][1];
+    virial[4] = v[0][0][2];
+    virial[5] = v[0][1][2];
+  }
+  if(vflag_atom) {
+    error->all(FLERR,"Pair style NEQUIP does not support per-atom virial");
+  }
 
   if(debug_mode){
     std::cout << "NequIP model output:\n";
     std::cout << "forces: " << forces_tensor << "\n";
     std::cout << "total_energy: " << total_energy_tensor << "\n";
     std::cout << "atomic_energy: " << atomic_energy_tensor << "\n";
+    if(vflag) std::cout << "virial: " << output.at("virial").toTensor().cpu() << std::endl;
   }
 
-  //std::cout << "atomic energy sum: " << atomic_energy_sum << std::endl;
-  //std::cout << "Total energy: " << total_energy_tensor << "\n";
-  //std::cout << "atomic energy shape: " << atomic_energy_tensor.sizes()[0] << "," << atomic_energy_tensor.sizes()[1] << std::endl;
-  //std::cout << "atomic energies: " << atomic_energy_tensor << std::endl;
-
   // Write forces and per-atom energies (0-based tags here)
   for(int itag = 0; itag < inum; itag++){
     int i = tag2i[itag];
     f[i][0] = forces[itag][0];
     f[i][1] = forces[itag][1];
     f[i][2] = forces[itag][2];
     if (eflag_atom) eatom[i] = atomic_energies[itag][0];
-    //printf("%d %d %g %g %g %g %g %g\n", i, type[i], pos[itag][0], pos[itag][1], pos[itag][2], f[i][0], f[i][1], f[i][2]);
   }
 
-  // TODO: Virial stuff? (If there even is a pairwise force concept here)
 
   // TODO: Performance: Depending on how the graph network works, using tags for edges may lead to shitty memory access patterns and performance.
   // It may be better to first create tag2i as a separate loop, then set edges[edge_counter][:] = (i, tag2i[jtag]).
```
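The flattening performed in the new `vflag` branch above takes the symmetric 3x3 virial NequIP outputs and stores it in the six-component order LAMMPS uses (xx, yy, zz, xy, xz, yz). A standalone Python sketch of that ordering (illustration only; plain lists, hypothetical helper name):

```python
# Sketch of the virial flattening in the vflag branch: LAMMPS stores a
# symmetric 3x3 tensor as six components ordered xx, yy, zz, xy, xz, yz.
def flatten_virial(v):
    """v: 3x3 nested list (symmetric); returns the 6-component LAMMPS order."""
    return [v[0][0], v[1][1], v[2][2], v[0][1], v[0][2], v[1][2]]

v = [[1.0, 4.0, 5.0],
     [4.0, 2.0, 6.0],
     [5.0, 6.0, 3.0]]
print(flatten_virial(v))  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```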

patch_lammps.sh

Lines changed: 1 addition & 2 deletions
```diff
@@ -62,7 +62,6 @@ fi
 
 echo "Updating CMakeLists.txt..."
 
-# Update CMakeLists.txt
 sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD 14)/" $lammps_dir/cmake/CMakeLists.txt
 
 # Add libtorch
@@ -73,4 +72,4 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
 target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}")
 EOF2
 
-echo "Done!"
+echo "Done!"
```
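The `sed` substitution above (bumping the C++ standard LAMMPS is built with from 11 to 14, which LibTorch requires) can be expressed equivalently in Python; a sketch for illustration only (the helper name is hypothetical):

```python
import re

# Sketch: the same CMakeLists.txt substitution patch_lammps.sh applies
# with sed, expressed in Python for illustration.
def patch_cxx_standard(cmake_text):
    return re.sub(
        r"set\(CMAKE_CXX_STANDARD 11\)",
        "set(CMAKE_CXX_STANDARD 14)",
        cmake_text,
    )

print(patch_cxx_standard("set(CMAKE_CXX_STANDARD 11)"))  # set(CMAKE_CXX_STANDARD 14)
```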

tests/test_data/test_repro.yaml

Lines changed: 8 additions & 0 deletions
```diff
@@ -2,6 +2,14 @@ run_name: minimal
 seed: 123
 dataset_seed: 456
 
+# from minimal_stress.yaml
+model_builders:
+ - SimpleIrrepsConfig
+ - EnergyModel
+ - PerSpeciesRescale
+ - StressForceOutput
+ - RescaleEnergyEtc
+
 # network
 num_basis: 4
 l_max: 1
```

tests/test_python_repro.py

Lines changed: 28 additions & 1 deletion
```diff
@@ -12,6 +12,7 @@
 from collections import Counter
 
 import ase
+import ase.units
 import ase.build
 import ase.io
 
```
```diff
@@ -144,9 +145,11 @@ def test_repro(deployed_model):
 
        compute atomicenergies all pe/atom
        compute totalatomicenergy all reduce sum c_atomicenergies
+        compute stress all pressure NULL virial # NULL means without temperature contribution
 
-        thermo_style custom step time temp pe c_totalatomicenergy etotal press spcpu cpuremain
+        thermo_style custom step time temp pe c_totalatomicenergy etotal press spcpu cpuremain c_stress[*]
        run 0
+        print "$({PRECISION_CONST} * c_stress[1]) $({PRECISION_CONST} * c_stress[2]) $({PRECISION_CONST} * c_stress[3]) $({PRECISION_CONST} * c_stress[4]) $({PRECISION_CONST} * c_stress[5]) $({PRECISION_CONST} * c_stress[6])" file stress.dat
        print $({PRECISION_CONST} * pe) file pe.dat
        print $({PRECISION_CONST} * c_totalatomicenergy) file totalatomicenergy.dat
        write_dump all custom output.dump id type x y z fx fy fz c_atomicenergies modify format float %20.15g
```
309312
float(Path(tmpdir + f"/totalatomicenergy.dat").read_text())
310313
/ PRECISION_CONST
311314
)
315+
# in `metal` units, pressure/stress has units bars
316+
# so need to convert
317+
lammps_stress = np.fromstring(
318+
Path(tmpdir + f"/stress.dat").read_text(), sep=" ", dtype=np.float64
319+
) * (ase.units.bar / PRECISION_CONST)
320+
# https://docs.lammps.org/compute_pressure.html
321+
# > The ordering of values in the symmetric pressure tensor is as follows: pxx, pyy, pzz, pxy, pxz, pyz.
322+
lammps_stress = np.array(
323+
[
324+
[lammps_stress[0], lammps_stress[3], lammps_stress[4]],
325+
[lammps_stress[3], lammps_stress[1], lammps_stress[5]],
326+
[lammps_stress[4], lammps_stress[5], lammps_stress[2]],
327+
]
328+
)
312329
assert np.allclose(lammps_pe, lammps_totalatomicenergy)
313330
assert np.allclose(
314331
structure.get_potential_energy(),
315332
lammps_pe,
316333
atol=1e-6,
317334
)
335+
if periodic:
336+
# In LAMMPS, the convention is that the stress tensor, and thus the pressure, is related to the virial
337+
# WITHOUT a sign change. In `nequip`, we chose currently to follow the virial = -stress x volume
338+
# convention => stress = -1/V * virial. ASE does not change the sign of the virial, so we have
339+
# to flip the sign from ASE for the comparison.
340+
assert np.allclose(
341+
-structure.get_stress(voigt=False),
342+
lammps_stress,
343+
atol=1e-6,
344+
)
