Skip to content

Commit c1f3bad

Browse files
authored
Merge pull request #348 from libAtoms/err_calc_stress_virial_voigt
Convert virial/stress to Voigt-6 when computing error table
2 parents 5b08766 + 19f30be commit c1f3bad

File tree

5 files changed

+65
-10
lines changed

5 files changed

+65
-10
lines changed

.github/workflows/pytests.yml

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ jobs:
1414
runs-on: ubuntu-latest
1515
strategy:
1616
matrix:
17-
python-version: [ "3.9" ]
17+
python-version: [ "3.10" ]
1818
max-parallel: 5
1919
env:
20-
coverage-on-version: "3.9"
20+
coverage-on-version: "3.10"
2121
use-mpi: True
2222
PIP_CONSTRAINT: pip_constraint.txt
2323
defaults:
@@ -42,13 +42,15 @@ jobs:
4242
run: |
4343
echo "numpy<2" >> $PIP_CONSTRAINT
4444
python3 -m pip install wheel setuptools numpy scipy click matplotlib pyyaml spglib rdkit==2024.3.3 flake8 pytest pytest-cov requests
45+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
4546
4647
- name: Install latest ASE from pypi
4748
run: |
4849
echo PIP_CONSTRAINT $PIP_CONSTRAINT
4950
python3 -m pip install ase
5051
echo -n "ASE VERSION "
5152
python3 -c "import ase; print(ase.__file__, ase.__version__)"
53+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
5254
5355
- name: Install intel-oneapi-mkl for phono3py
5456
run: |
@@ -61,6 +63,7 @@ jobs:
6163
sudo apt update
6264
sudo apt install intel-oneapi-mkl
6365
sudo apt install intel-oneapi-mkl-devel
66+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
6467
6568
- name: Install phono3py from source
6669
run: |
@@ -90,16 +93,22 @@ jobs:
9093
cd phono3py
9194
python3 -m pip install -e . -vvv
9295
cd ..
96+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
9397
9498
- name: Install Quippy from PyPI
95-
run: python3 -m pip install quippy-ase
99+
run: |
100+
python3 -m pip install quippy-ase
101+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
96102
97103
- name: Install xTB (before things that need pandas like MACE and wfl, since it will break pandas-numpy compatibility by downgrading numpy)
98104
run: |
105+
# force compatible numpy version
106+
conda install 'numpy<2'
99107
conda install -c conda-forge xtb-python
100108
python3 -m pip install typing-extensions
101109
# install pandas now to encourage compatible numpy version after conda regressed it
102110
python3 -m pip install pandas
111+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
103112
104113
- name: MACE
105114
run: |
@@ -136,9 +145,11 @@ jobs:
136145
fi
137146
echo "found torch version ${torch_version}+cpu, installing"
138147
python3 -m pip install torch==${torch_version}+cpu -f https://download.pytorch.org/whl/torch_stable.html
148+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
139149
echo "installing mace"
140150
python3 -m pip install git+https://github.com/ACEsuit/mace.git@main
141-
python3 -c "import mace; print(mace.__file__)"
151+
python3 -c "import mace; print('mace file', mace.__file__)"
152+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
142153
143154
- name: Julia and ace fit
144155
run: |
@@ -148,19 +159,24 @@ jobs:
148159
# note that this hardwires a particular compatible ACE1pack version
149160
echo 'using Pkg; pkg"registry add https://github.com/JuliaRegistries/General"; pkg"registry add https://github.com/JuliaMolSim/MolSim.git"; pkg"add ACE1pack@0.0, ACE1, JuLIP, IPFitting, ASE"' > ace1pack_install.jl
150161
${PWD}/julia-1.8.1/bin/julia ace1pack_install.jl
162+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
151163
152164
- name: Install wfl (expyre and universalSOAP are dependencies)
153-
run: python3 -m pip install .
165+
run: |
166+
python3 -m pip install .
167+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
154168
155169
- name: Install Quantum Espresso
156170
run: |
157171
sudo apt-get install --no-install-recommends quantum-espresso
172+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
158173
159174
- name: Install MOPAC
160175
run: |
161-
wget http://openmopac.net/mopac-22.1.1-linux.tar.gz
176+
wget https://github.com/openmopac/mopac/releases/download/v22.1.1/mopac-22.1.1-linux.tar.gz
162177
tar -xzvf mopac-22.1.1-linux.tar.gz
163178
echo $GITHUB_WORKSPACE/mopac-22.1.1-linux/bin >> $GITHUB_PATH
179+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
164180
165181
- name: Install buildcell
166182
run: |
@@ -172,6 +188,7 @@ jobs:
172188
mkdir -p $HOME/bin
173189
cp src/buildcell/src/buildcell $HOME/bin/
174190
cd ..
191+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
175192
176193
- name: Add buildcell to system path
177194
run: |
@@ -183,6 +200,7 @@ jobs:
183200
# this can eaily be turned off if needed
184201
conda install -c conda-forge mpi4py openmpi pytest-mpi
185202
python3 -m pip install mpipool
203+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
186204
187205
- name: Install and configure slurm and ExPyRe
188206
run: |
@@ -206,6 +224,7 @@ jobs:
206224
sinfo -s --long
207225
mkdir $HOME/.expyre
208226
cp .github/workflows_assets/config.json $HOME/.expyre
227+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
209228
210229
- name: Set up pw.x for running in wfl
211230
run: |
@@ -221,6 +240,7 @@ jobs:
221240
222241
echo 'post-espresso $HOME/.config/ase/config.ini'
223242
cat $HOME/.config/ase/config.ini
243+
python3 -c "import numpy; print('numpy version', numpy.__version__)"
224244
225245
- name: Lint with flake8
226246
run: |
@@ -234,6 +254,8 @@ jobs:
234254
run: |
235255
rm -rf $HOME/pytest_plain
236256
mkdir $HOME/pytest_plain
257+
# attempt to work around mkl/numpy issue
258+
export MKL_THREADING_LAYER=GNU
237259
#
238260
export EXPYRE_PYTEST_SYSTEMS=github
239261
export WFL_PYTEST_BUILDCELL=$HOME/bin/buildcell
@@ -247,6 +269,8 @@ jobs:
247269
run: |
248270
rm -rf $HOME/pytest_cov
249271
mkdir $HOME/pytest_cov
272+
# attempt to work around mkl/numpy issue
273+
export MKL_THREADING_LAYER=GNU
250274
#
251275
export EXPYRE_PYTEST_SYSTEMS=github
252276
export WFL_PYTEST_BUILDCELL=$HOME/bin/buildcell

tests/test_error.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
from ase.atoms import Atoms
77
from ase.calculators.lj import LennardJones
8+
from ase.stress import voigt_6_to_full_3x3_stress
89
from pytest import approx
910

1011
from pprint import pprint
@@ -103,6 +104,17 @@ def test_err_from_calc(ref_atoms):
103104
assert ref_err_dict['virial/atom/comp']['_ALL_']["count"] == 10 * 6
104105

105106

107+
def test_err_stress_shape(ref_atoms):
108+
ref_atoms_calc = generic_calc(ref_atoms, OutputSpec(), LennardJones(sigma=0.75), output_prefix='calc_')
109+
ref_err_dict, _, _ = ref_err_calc(ref_atoms_calc, ref_property_prefix='REF_', calc_property_prefix='calc_')
110+
111+
for at in ref_atoms_calc:
112+
at.info["REF_stress"] = voigt_6_to_full_3x3_stress(at.info["REF_stress"])
113+
ref_err_dict_shape, _, _ = ref_err_calc(ref_atoms_calc, ref_property_prefix='REF_', calc_property_prefix='calc_')
114+
115+
assert ref_err_dict == ref_err_dict_shape
116+
117+
106118
def test_error_properties(ref_atoms):
107119
ref_atoms_calc = generic_calc(ref_atoms, OutputSpec(), LennardJones(sigma=0.75), output_prefix='calc_')
108120
# both energy and per atom

tests/test_md.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def test_md_attach_logger(cu_slab, tmp_path, monkeypatch):
268268
workdir = Path(os.getcwd())
269269

270270
assert len(atoms_traj) == 602
271-
assert all([Path(workdir / "test_log.item_0").is_file(), Path(workdir / "test_log.item_1").is_file()])
271+
assert all([Path(workdir / "test_log.config_0").is_file(), Path(workdir / "test_log.config_1").is_file()])
272272

273273

274274
def test_md_attach_logger_stdout(cu_slab, tmp_path, monkeypatch, capsys):

wfl/fit/error.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from matplotlib.figure import Figure
88
from matplotlib.pyplot import get_cmap
99

10+
from ase.stress import full_3x3_to_voigt_6_stress
11+
1012

1113
def calc(inputs, calc_property_prefix, ref_property_prefix,
1214
config_properties=None, atom_properties=None, category_keys="config_type",
@@ -83,8 +85,21 @@ def _reshape_normalize(quant, prop, atoms, per_atom):
8385
quant: 2-d array containing reshaped quantity, with leading dimension 1 for per-config
8486
or len(atoms) for per-atom
8587
"""
86-
# convert scalars or lists into arrays
87-
quant = np.asarray(quant)
88+
89+
# fix shape of stress/virial
90+
if prop.startswith("stress") or prop.startswith("virial"):
91+
if prop.split("/")[0] in ["stress", "virial"]:
92+
if quant.shape != (6,):
93+
if quant.shape not in [(9,), (3,3)]:
94+
raise ValueError(f"Prop '{prop}' has unknown shape of quant {quant.shape}")
95+
quant = full_3x3_to_voigt_6_stress(quant.reshape((3, 3)))
96+
elif prop.split("/")[0] in ["stresses", "virials"]:
97+
eff_quant_shape = quant.shape[1:]
98+
if eff_quant_shape != (6,):
99+
if eff_quant_shape not in [(9,), (3,3)]:
100+
raise ValueError(f"Prop '{prop}' has unknown shape of quant {quant.shape}")
101+
quant = [full_3x3_to_voigt_6_stress(q.reshape((3, 3))) for q in quant]
102+
quant = np.asarray(quant)
88103

89104
# Reshape to 2-d, with leading dimension 1 for per-config, and len(atoms) for per-atom.
90105
# This is the right shape to work with later flattening for per-property and norm calculation
@@ -150,6 +165,10 @@ def _reshape_normalize(quant, prop, atoms, per_atom):
150165

151166
continue
152167

168+
# make a copy so normalization doesn't affect original
169+
ref_quant = np.asarray(ref_quant).copy()
170+
calc_quant = np.asarray(calc_quant).copy()
171+
153172
if virial_from_stress:
154173
# ref quant was actually stress, automatically convert
155174
ref_quant *= -at.get_volume()

wfl/fit/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import shlex
66
import warnings
77

8-
from ase.constraints import voigt_6_to_full_3x3_stress
8+
from ase.stress import voigt_6_to_full_3x3_stress
99

1010
from wfl.utils.julia import julia_exec_path
1111

0 commit comments

Comments
 (0)