Skip to content

Commit 01df21e

Browse files
committed
Refactor ForceField class to facilitate future extensions
1 parent d04ee8b commit 01df21e

File tree

6 files changed

+84
-54
lines changed

6 files changed

+84
-54
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Changed
1111

12-
- Refactored neighborlist API, to prepare for more efficient implementations.
12+
- Refactor `ForceField` class to facilitate future extensions.
13+
- Refactor neighborlist API, to prepare for more efficient implementations.
1314

1415

1516
## [0.2.2] - 2024-10-09

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ This is the part that you are expected to write.
5858
The evaluation of the force field energy and its derivatives requires the following:
5959

6060
```python
61-
from tinyff.forcefield import CutoffWrapper, LennardJones, PairwiseForceField
61+
from tinyff.forcefield import CutoffWrapper, LennardJones, ForceField
6262
from tinyff.neighborlist import NBuildSimple
6363

6464
# Define a pairwise potential, with energy and force shift
6565
rcut = 5.0
6666
lj = CutOffWrapper(LennardJones(2.5, 2.0), rcut)
6767

6868
# Define a force field
69-
pwff = PairwiseForceField(lj, NBuildSimple(rcut))
69+
ff = ForceField([lj], NBuildSimple(rcut))
7070

7171
# You need atomic positions and the length of a periodic cell edge.
7272
# The following line defines just two atomic positions.
@@ -80,11 +80,11 @@ cell_length = 20.0
8080
# - An array with Cartesian forces, same shape as `atpos`.
8181
# - The force contribution the pressure
8282
# (often the written as the second term in the virial pressure).
83-
potential_energy, forces, frc_pressure = pwff(atpos, cell_length)
83+
potential_energy, forces, frc_pressure = ff(atpos, cell_length)
8484
```
8585

8686
This basic recipe can be extended by passing additional options
87-
into the `PairwiseForceField` constructor:
87+
into the `ForceField` constructor:
8888

8989
- Linear-scaling neighbor lists with the
9090
[cell lists](https://en.wikipedia.org/wiki/Cell_lists) method:
@@ -93,7 +93,7 @@ into the `PairwiseForceField` constructor:
9393
from tinyff.neighborlist import NBuildCellLists
9494

9595
# Construct your force field object as follows:
96-
pwff = PairwiseForceField(lj, NBuildCellLists(rcut))
96+
ff = ForceField([lj], NBuildCellLists(rcut))
9797
```
9898

9999
Note that the current cell lists implementation is not very efficient (yet),
@@ -104,7 +104,7 @@ into the `PairwiseForceField` constructor:
104104

105105
```python
106106
rmax = 6.0 # > rcut, so buffer of 1.0
107-
pwff = PairwiseForceField(lj, NBuildSimple(rmax, nlist_reuse=16))
107+
ff = ForceField([lj], NBuildSimple(rmax, nlist_reuse=16))
108108
```
109109

110110

src/tinyff/atomsmithy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from numpy.typing import ArrayLike, NDArray
2424
from scipy.optimize import minimize
2525

26-
from .forcefield import PairPotential, PairwiseForceField
26+
from .forcefield import ForceField, PairPotential
2727
from .neighborlist import NBuild, NBuildSimple
2828

2929
__all__ = (
@@ -86,7 +86,7 @@ class PushPotential(PairPotential):
8686

8787
rcut: float = attrs.field(converter=float, validator=attrs.validators.gt(0))
8888

89-
def __call__(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
89+
def compute(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
9090
"""Compute pair potential energy and its derivative towards distance."""
9191
dist = np.asarray(dist, dtype=float)
9292
x = dist / self.rcut
@@ -113,11 +113,11 @@ def build_random_cell(
113113
# Define cost function to push the atoms appart.
114114
if nbuild is None:
115115
nbuild = NBuildSimple(rcut)
116-
pwff = PairwiseForceField(PushPotential(rcut), nbuild)
116+
ff = ForceField([PushPotential(rcut)], nbuild)
117117

118118
def costgrad(atpos_raveled):
119119
atpos = atpos_raveled.reshape(-1, 3)
120-
energy, force, _ = pwff(atpos, cell_length)
120+
energy, force, _ = ff(atpos, cell_length)
121121
return energy, -force.ravel()
122122

123123
# Optimize and return structure

src/tinyff/forcefield.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,25 @@
2222
import numpy as np
2323
from numpy.typing import ArrayLike, NDArray
2424

25-
from .neighborlist import NBuild
25+
from .neighborlist import NLIST_DTYPE, NBuild
2626

27-
__all__ = ("PairwiseForceField",)
27+
__all__ = ("ForceTerm", "PairPotential", "LennardJones", "CutOffWrapper", "ForceField")
28+
29+
30+
@attrs.define
31+
class ForceTerm:
32+
def __call__(self, nlist: NDArray[NLIST_DTYPE]):
33+
raise NotImplementedError # pragma: nocover
2834

2935

3036
@attrs.define
3137
class PairPotential:
32-
def __call__(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
38+
def __call__(self, nlist: NDArray[NLIST_DTYPE]):
39+
energy, gdist = self.compute(nlist["dist"])
40+
nlist["energy"] += energy
41+
nlist["gdist"] += gdist
42+
43+
def compute(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
3344
"""Compute pair potential energy and its derivative towards distance."""
3445
raise NotImplementedError # pragma: nocover
3546

@@ -39,7 +50,7 @@ class LennardJones(PairPotential):
3950
epsilon: float = attrs.field(converter=float)
4051
sigma: float = attrs.field(converter=float)
4152

42-
def __call__(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
53+
def compute(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
4354
"""Compute pair potential energy and its derivative towards distance."""
4455
dist = np.asarray(dist, dtype=float)
4556
x = self.sigma / dist
@@ -57,23 +68,23 @@ class CutOffWrapper(PairPotential):
5768

5869
def __attrs_post_init__(self):
5970
"""Post initialization changes."""
60-
self.ecut, self.gcut = self.original(self.rcut)
71+
self.ecut, self.gcut = self.original.compute(self.rcut)
6172

62-
def __call__(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
73+
def compute(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
6374
"""Compute pair potential energy and its derivative towards distance."""
6475
dist = np.asarray(dist, dtype=float)
6576
mask = dist < self.rcut
6677
if mask.ndim == 0:
6778
# Deal with non-array case
6879
if mask:
69-
energy, gdist = self.original(dist)
80+
energy, gdist = self.original.compute(dist)
7081
energy -= self.ecut + self.gcut * (dist - self.rcut)
7182
gdist -= self.gcut
7283
else:
7384
energy = 0.0
7485
gdist = 0.0
7586
else:
76-
energy, gdist = self.original(dist)
87+
energy, gdist = self.original.compute(dist)
7788
energy[mask] -= self.ecut + self.gcut * (dist[mask] - self.rcut)
7889
energy[~mask] = 0.0
7990
gdist[mask] -= self.gcut
@@ -82,11 +93,9 @@ def __call__(self, dist: ArrayLike) -> tuple[NDArray, NDArray]:
8293

8394

8495
@attrs.define
85-
class PairwiseForceField:
86-
pair_potential: PairPotential = attrs.field(
87-
validator=attrs.validators.instance_of(PairPotential)
88-
)
89-
"""A definition of the pair potential."""
96+
class ForceField:
97+
force_terms: list[ForceTerm] = attrs.field()
98+
"""A list of contributions to the potential energy."""
9099

91100
nbuild: NBuild = attrs.field(validator=attrs.validators.instance_of(NBuild))
92101
"""Algorithm to build the neigborlist."""
@@ -116,10 +125,11 @@ def __call__(self, atpos: NDArray, cell_length: float):
116125
self.nbuild.update(atpos, cell_length)
117126
nlist = self.nbuild.nlist
118127
# Compute all pairwise quantities
119-
nlist["energy"], nlist["gdist"] = self.pair_potential(nlist["dist"])
120-
nlist["gdelta"] = (nlist["gdist"] / nlist["dist"]).reshape(-1, 1) * nlist["delta"]
128+
for force_term in self.force_terms:
129+
force_term(nlist)
121130
# Compute the totals
122131
energy = nlist["energy"].sum()
132+
nlist["gdelta"] = (nlist["gdist"] / nlist["dist"]).reshape(-1, 1) * nlist["delta"]
123133
forces = np.zeros(atpos.shape, dtype=float)
124134
np.add.at(forces, nlist["iatom0"], nlist["gdelta"])
125135
np.add.at(forces, nlist["iatom1"], -nlist["gdelta"])

tests/test_atomsmithy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,18 +112,18 @@ def test_fcc_lattice():
112112
def test_push_derivative():
113113
pp = PushPotential(2.5)
114114
dist = np.linspace(0.4, 3.0, 50)
115-
gdist1 = pp(dist)[1]
116-
gdist2 = nd.Derivative(lambda dist: pp(dist)[0])(dist)
115+
gdist1 = pp.compute(dist)[1]
116+
gdist2 = nd.Derivative(lambda dist: pp.compute(dist)[0])(dist)
117117
assert gdist1 == pytest.approx(gdist2)
118118

119119

120120
def test_push_cutoff():
121121
pp = PushPotential(2.5)
122122
eps = 1e-13
123-
assert abs(pp(2.5 - 0.1)[0]) > eps
124-
assert abs(pp(2.5 - 0.1)[1]) > eps
125-
assert abs(pp(2.5 - eps)[0]) < eps
126-
assert abs(pp(2.5 - eps)[1]) < eps
123+
assert abs(pp.compute(2.5 - 0.1)[0]) > eps
124+
assert abs(pp.compute(2.5 - 0.1)[1]) > eps
125+
assert abs(pp.compute(2.5 - eps)[0]) < eps
126+
assert abs(pp.compute(2.5 - eps)[1]) < eps
127127

128128

129129
def test_random_box():

tests/test_forcefield.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,36 +22,37 @@
2222
import numpy as np
2323
import pytest
2424

25-
from tinyff.forcefield import CutOffWrapper, LennardJones, PairwiseForceField
25+
from tinyff.atomsmithy import PushPotential
26+
from tinyff.forcefield import CutOffWrapper, ForceField, LennardJones
2627
from tinyff.neighborlist import NBuildCellLists, NBuildSimple
2728

2829

2930
def test_lennard_jones_derivative():
3031
lj = LennardJones(2.5, 0.5)
3132
dist = np.linspace(0.4, 3.0, 50)
32-
gdist1 = lj(dist)[1]
33-
gdist2 = nd.Derivative(lambda dist: lj(dist)[0])(dist)
33+
gdist1 = lj.compute(dist)[1]
34+
gdist2 = nd.Derivative(lambda dist: lj.compute(dist)[0])(dist)
3435
assert gdist1 == pytest.approx(gdist2)
3536

3637

3738
def test_lennard_jones_cut_derivative():
3839
lj = CutOffWrapper(LennardJones(2.5, 0.5), 3.5)
3940
dist = np.linspace(0.4, 5.0, 50)
40-
gdist1 = lj(dist)[1]
41-
gdist2 = nd.Derivative(lambda x: lj(x)[0])(dist)
41+
gdist1 = lj.compute(dist)[1]
42+
gdist2 = nd.Derivative(lambda x: lj.compute(x)[0])(dist)
4243
assert gdist1 == pytest.approx(gdist2)
4344

4445

4546
def test_lennard_jones_cut_zero_array():
4647
lj = CutOffWrapper(LennardJones(2.5, 0.5), 3.5)
47-
e, g = lj([5.0, 3.6])
48+
e, g = lj.compute([5.0, 3.6])
4849
assert (e == 0.0).all()
4950
assert (g == 0.0).all()
5051

5152

5253
def test_lennard_jones_cut_zero_scalar():
5354
lj = CutOffWrapper(LennardJones(2.5, 0.5), 3.5)
54-
e, g = lj(5.0)
55+
e, g = lj.compute(5.0)
5556
assert e == 0.0
5657
assert g == 0.0
5758

@@ -65,13 +66,12 @@ def test_pairwise_force_field_two(nbuild_class):
6566
# Define the force field.
6667
rcut = 8.0
6768
lj = CutOffWrapper(LennardJones(2.5, 1.3), rcut)
68-
nbuild = nbuild_class(rcut)
69-
pwff = PairwiseForceField(lj, nbuild)
69+
ff = ForceField([lj], nbuild_class(rcut))
7070

7171
# Compute and check against manual result
72-
energy, forces, frc_press = pwff(atpos, cell_length)
72+
energy, forces, frc_press = ff(atpos, cell_length)
7373
d = np.linalg.norm(atpos[0] - atpos[1])
74-
e, g = lj(d)
74+
e, g = lj.compute(d)
7575
assert energy == pytest.approx(e)
7676
assert forces == pytest.approx(np.array([[g, 0.0, 0.0], [-g, 0.0, 0.0]]))
7777
assert frc_press == pytest.approx(-g * d / (3 * cell_length**3))
@@ -86,31 +86,30 @@ def test_pairwise_force_field_three(nbuild_class):
8686
# Define the force field.
8787
rcut = 8.0
8888
lj = CutOffWrapper(LennardJones(2.5, 1.3), rcut)
89-
nbuild = nbuild_class(rcut)
90-
pwff = PairwiseForceField(lj, nbuild)
89+
ff = ForceField([lj], nbuild_class(rcut))
9190

9291
# Compute the energy, the forces and the force contribution pressure.
93-
energy1, forces1, frc_press1 = pwff(atpos, cell_length)
92+
energy1, forces1, frc_press1 = ff(atpos, cell_length)
9493

9594
# Compute the energy manually and compare.
9695
dists = [
9796
np.linalg.norm(atpos[1] - atpos[2]),
9897
np.linalg.norm(atpos[2] - atpos[0]),
9998
np.linalg.norm(atpos[0] - atpos[1]),
10099
]
101-
energy2 = lj(dists)[0].sum()
100+
energy2 = lj.compute(dists)[0].sum()
102101
assert energy1 == pytest.approx(energy2)
103102

104103
# Test forces with numdifftool
105-
forces2 = -nd.Gradient(lambda x: pwff(x.reshape(-1, 3), cell_length)[0])(atpos)
104+
forces2 = -nd.Gradient(lambda x: ff(x.reshape(-1, 3), cell_length)[0])(atpos)
106105
forces2.shape = (-1, 3)
107106
assert forces1 == pytest.approx(forces2.reshape(-1, 3))
108107

109108
# Test pressure with numdifftool
110109
def energy_volume(volume):
111110
my_cell_length = volume ** (1.0 / 3.0)
112111
scale = my_cell_length / cell_length
113-
return pwff(atpos * scale, my_cell_length)[0]
112+
return ff(atpos * scale, my_cell_length)[0]
114113

115114
frc_press2 = -nd.Derivative(energy_volume)(cell_length**3)
116115
assert frc_press1 == pytest.approx(frc_press2)
@@ -143,23 +142,43 @@ def test_pairwise_force_field_fifteen(nbuild_class):
143142
# Define the force field.
144143
rcut = 8.0
145144
lj = CutOffWrapper(LennardJones(2.5, 1.3), rcut)
146-
nbuild = nbuild_class(rcut)
147-
pwff = PairwiseForceField(lj, nbuild)
145+
ff = ForceField([lj], nbuild_class(rcut))
148146

149147
# Compute the energy, the forces and the force contribution to the pressure.
150-
energy, forces1, frc_press1 = pwff(atpos, cell_length)
148+
energy, forces1, frc_press1 = ff(atpos, cell_length)
151149
assert energy < 0
152150

153151
# Test forces with numdifftool
154-
forces2 = -nd.Gradient(lambda x: pwff(x.reshape(-1, 3), cell_length)[0])(atpos)
152+
forces2 = -nd.Gradient(lambda x: ff(x.reshape(-1, 3), cell_length)[0])(atpos)
155153
forces2.shape = (-1, 3)
156154
assert forces1 == pytest.approx(forces2.reshape(-1, 3))
157155

158156
# Test pressure with numdifftool
159157
def energy_volume(volume):
160158
my_cell_length = volume ** (1.0 / 3.0)
161159
scale = my_cell_length / cell_length
162-
return pwff(atpos * scale, my_cell_length)[0]
160+
return ff(atpos * scale, my_cell_length)[0]
163161

164162
frc_press2 = -nd.Derivative(energy_volume)(cell_length**3)
165163
assert frc_press1 == pytest.approx(frc_press2)
164+
165+
166+
@pytest.mark.parametrize("nbuild_class", [NBuildSimple, NBuildCellLists])
167+
def test_superposition(nbuild_class):
168+
atpos = np.array([[0.0, 0.0, 5.0], [0.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
169+
cell_length = 10.0
170+
171+
# Define the force field.
172+
rcut = 4.9
173+
lj = CutOffWrapper(LennardJones(2.5, 1.3), rcut)
174+
pp = PushPotential(rcut)
175+
ff_lj = ForceField([lj], nbuild_class(rcut))
176+
ff_pp = ForceField([pp], nbuild_class(rcut))
177+
ff_su = ForceField([lj, pp], nbuild_class(rcut))
178+
179+
energy_lj, forces_lj, frc_press_lj = ff_lj(atpos, cell_length)
180+
energy_pp, forces_pp, frc_press_pp = ff_pp(atpos, cell_length)
181+
energy_su, forces_su, frc_press_su = ff_su(atpos, cell_length)
182+
assert energy_lj + energy_pp == pytest.approx(energy_su)
183+
assert forces_lj + forces_pp == pytest.approx(forces_su)
184+
assert frc_press_lj + frc_press_pp == pytest.approx(frc_press_su)

0 commit comments

Comments
 (0)