Skip to content

Commit f242cbe

Browse files
Merge pull request #221 from amcadmus/master
Merge bug fixings to master
2 parents 1d1084d + 6569780 commit f242cbe

File tree

5 files changed

+42
-14
lines changed

5 files changed

+42
-14
lines changed

.github/workflows/test_import.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
name: test Python import
2+
3+
on:
4+
- push
5+
- pull_request
6+
7+
jobs:
8+
build:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- uses: actions/checkout@v2
12+
- uses: actions/setup-python@v2
13+
with:
14+
python-version: '3.9'
15+
architecture: 'x64'
16+
- run: python -m pip install .
17+
- run: python -c 'import dpdata'
18+

dpdata/deepmd/comp.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@ def _cond_load_data(fname) :
88
tmp = np.load(fname)
99
return tmp
1010

11-
def _load_set(folder) :
12-
cells = np.load(os.path.join(folder, 'box.npy'))
11+
def _load_set(folder, nopbc: bool) :
1312
coords = np.load(os.path.join(folder, 'coord.npy'))
13+
if nopbc:
14+
cells = np.zeros((coords.shape[0], 3,3))
15+
else:
16+
cells = np.load(os.path.join(folder, 'box.npy'))
1417
eners = _cond_load_data(os.path.join(folder, 'energy.npy'))
1518
forces = _cond_load_data(os.path.join(folder, 'force.npy'))
1619
virs = _cond_load_data(os.path.join(folder, 'virial.npy'))
@@ -22,14 +25,16 @@ def to_system_data(folder,
2225
# data is empty
2326
data = load_type(folder, type_map = type_map)
2427
data['orig'] = np.zeros([3])
28+
if os.path.isfile(os.path.join(folder, "nopbc")):
29+
data['nopbc'] = True
2530
sets = sorted(glob.glob(os.path.join(folder, 'set.*')))
2631
all_cells = []
2732
all_coords = []
2833
all_eners = []
2934
all_forces = []
3035
all_virs = []
3136
for ii in sets :
32-
cells, coords, eners, forces, virs = _load_set(ii)
37+
cells, coords, eners, forces, virs = _load_set(ii, data.get('nopbc', False))
3338
nframes = np.reshape(cells, [-1,3,3]).shape[0]
3439
all_cells.append(np.reshape(cells, [nframes,3,3]))
3540
all_coords.append(np.reshape(coords, [nframes,-1,3]))
@@ -50,8 +55,6 @@ def to_system_data(folder,
5055
data['forces'] = np.concatenate(all_forces, axis = 0)
5156
if len(all_virs) > 0:
5257
data['virials'] = np.concatenate(all_virs, axis = 0)
53-
if os.path.isfile(os.path.join(folder, "nopbc")):
54-
data['nopbc'] = True
5558
return data
5659

5760

dpdata/deepmd/raw.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@ def to_system_data(folder, type_map = None, labels = True) :
3333
if os.path.isdir(folder) :
3434
data = load_type(folder, type_map = type_map)
3535
data['orig'] = np.zeros([3])
36-
data['cells'] = np.loadtxt(os.path.join(folder, 'box.raw'))
37-
data['coords'] = np.loadtxt(os.path.join(folder, 'coord.raw'))
38-
data['cells'] = np.reshape(data['cells'], [-1, 3, 3])
39-
nframes = data['cells'].shape[0]
36+
data['coords'] = np.loadtxt(os.path.join(folder, 'coord.raw'), ndmin=2)
37+
nframes = data['coords'].shape[0]
38+
if os.path.isfile(os.path.join(folder, "nopbc")):
39+
data['nopbc'] = True
40+
data['cells'] = np.zeros((nframes, 3,3))
41+
else:
42+
data['cells'] = np.loadtxt(os.path.join(folder, 'box.raw'), ndmin=2)
4043
data['cells'] = np.reshape(data['cells'], [nframes, 3, 3])
4144
data['coords'] = np.reshape(data['coords'], [nframes, -1, 3])
4245
if labels :

dpdata/pymatgen/molecule.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy as np
2-
from pymatgen.core import Molecule
2+
try:
3+
from pymatgen.core import Molecule
4+
except ImportError:
5+
pass
36
from collections import Counter
47
import dpdata
58

tests/comp_sys.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ def test_nframs(self):
3535
def test_cell(self):
3636
self.assertEqual(self.system_1.get_nframes(),
3737
self.system_2.get_nframes())
38-
np.testing.assert_almost_equal(self.system_1.data['cells'],
39-
self.system_2.data['cells'],
40-
decimal = self.places,
41-
err_msg = 'cell failed')
38+
if not self.system_1.nopbc and not self.system_2.nopbc:
39+
np.testing.assert_almost_equal(self.system_1.data['cells'],
40+
self.system_2.data['cells'],
41+
decimal = self.places,
42+
err_msg = 'cell failed')
4243

4344
def test_coord(self):
4445
self.assertEqual(self.system_1.get_nframes(),

0 commit comments

Comments
 (0)