Skip to content

Commit 3c595ca

Browse files
authored
use np.testing.assert_almost_equal to compare numpy arrays (#201)
1 parent 7232d7d commit 3c595ca

File tree

1 file changed

+21
-34
lines changed

1 file changed

+21
-34
lines changed

tests/comp_sys.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,22 @@ def test_nframs(self):
3535
def test_cell(self):
3636
self.assertEqual(self.system_1.get_nframes(),
3737
self.system_2.get_nframes())
38-
for ff in range(self.system_1.get_nframes()) :
39-
for ii in range(3) :
40-
for jj in range(3) :
41-
self.assertAlmostEqual(self.system_1.data['cells'][ff][ii][jj],
42-
self.system_2.data['cells'][ff][ii][jj],
43-
places = self.places,
44-
msg = 'cell[%d][%d][%d] failed' % (ff,ii,jj))
38+
np.testing.assert_almost_equal(self.system_1.data['cells'],
39+
self.system_2.data['cells'],
40+
decimal = self.places,
41+
err_msg = 'cell failed')
4542

4643
def test_coord(self):
4744
self.assertEqual(self.system_1.get_nframes(),
4845
self.system_2.get_nframes())
4946
# think about direct coord
5047
tmp_cell = self.system_1.data['cells']
5148
tmp_cell = np.reshape(tmp_cell, [-1, 3])
52-
tmp_cell_norm = np.reshape(np.linalg.norm(tmp_cell, axis = 1), [-1, 3])
53-
for ff in range(self.system_1.get_nframes()) :
54-
for ii in range(sum(self.system_1.data['atom_numbs'])) :
55-
for jj in range(3) :
56-
self.assertAlmostEqual(self.system_1.data['coords'][ff][ii][jj] / tmp_cell_norm[ff][jj],
57-
self.system_2.data['coords'][ff][ii][jj] / tmp_cell_norm[ff][jj],
58-
places = self.places,
59-
msg = 'coord[%d][%d][%d] failed' % (ff,ii,jj))
49+
tmp_cell_norm = np.reshape(np.linalg.norm(tmp_cell, axis = 1), [-1, 1, 3])
50+
np.testing.assert_almost_equal(self.system_1.data['coords'] / tmp_cell_norm,
51+
self.system_2.data['coords'] / tmp_cell_norm,
52+
decimal = self.places,
53+
err_msg = 'coord failed')
6054

6155
def test_nopbc(self):
6256
self.assertEqual(self.system_1.nopbc, self.system_2.nopbc)
@@ -66,22 +60,18 @@ class CompLabeledSys (CompSys) :
6660
def test_energy(self) :
6761
self.assertEqual(self.system_1.get_nframes(),
6862
self.system_2.get_nframes())
69-
for ff in range(self.system_1.get_nframes()) :
70-
self.assertAlmostEqual(self.system_1.data['energies'][ff],
71-
self.system_2.data['energies'][ff],
72-
places = self.e_places,
73-
msg = 'energies[%d] failed' % (ff))
63+
np.testing.assert_almost_equal(self.system_1.data['energies'],
64+
self.system_2.data['energies'],
65+
decimal = self.e_places,
66+
err_msg = 'energies failed')
7467

7568
def test_force(self) :
7669
self.assertEqual(self.system_1.get_nframes(),
7770
self.system_2.get_nframes())
78-
for ff in range(self.system_1.get_nframes()) :
79-
for ii in range(self.system_1.data['forces'].shape[1]) :
80-
for jj in range(3) :
81-
self.assertAlmostEqual(self.system_1.data['forces'][ff][ii][jj],
82-
self.system_2.data['forces'][ff][ii][jj],
83-
places = self.f_places,
84-
msg = 'forces[%d][%d][%d] failed' % (ff,ii,jj))
71+
np.testing.assert_almost_equal(self.system_1.data['forces'],
72+
self.system_2.data['forces'],
73+
decimal = self.f_places,
74+
err_msg = 'forces failed')
8575

8676
def test_virial(self) :
8777
self.assertEqual(self.system_1.get_nframes(),
@@ -92,13 +82,10 @@ def test_virial(self) :
9282
if not 'virials' in self.system_1:
9383
self.assertFalse('virials' in self.system_2)
9484
return
95-
for ff in range(self.system_1.get_nframes()) :
96-
for ii in range(3) :
97-
for jj in range(3) :
98-
self.assertAlmostEqual(self.system_1['virials'][ff][ii][jj],
99-
self.system_2['virials'][ff][ii][jj],
100-
places = self.v_places,
101-
msg = 'virials[%d][%d][%d] failed' % (ff,ii,jj))
85+
np.testing.assert_almost_equal(self.system_1['virials'],
86+
self.system_2['virials'],
87+
decimal = self.v_places,
88+
err_msg = 'virials failed')
10289

10390

10491
class MultiSystems:

0 commit comments

Comments
 (0)