Skip to content

Commit 8cf52fd

Browse files
authored
use np.testing.assert_almost_equal for array comparing (#1059)
* use `np.testing.assert_almost_equal` for array comparing * bugfix * bugfix * bugfix * bugfix * bugfix * bugfix * bugfix
1 parent 37fe14c commit 8cf52fd

39 files changed

+275
-558
lines changed

source/tests/common.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,9 @@ def virial_test (inter,
258258
num_vir = np.transpose(num_vir, [1,0])
259259
box3 = dbox[0].reshape([3,3])
260260
num_vir = np.matmul(num_vir, box3)
261-
for ii in range(3):
262-
for jj in range(3):
263-
testCase.assertAlmostEqual(ana_vir[ii][jj], num_vir[ii][jj],
264-
places=places,
265-
msg = 'virial component %d %d ' % (ii,jj))
261+
np.testing.assert_almost_equal(ana_vir, num_vir,
262+
places,
263+
err_msg = 'virial component')
266264

267265

268266

source/tests/test_data_modifier.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,9 @@ def _test_fv (self):
119119
ep, _, __ = dcm.eval(coordp, box, atype, eval_fv = False)
120120
em, _, __ = dcm.eval(coordm, box, atype, eval_fv = False)
121121
num_f = -(ep - em) / (2.*hh)
122-
for ff in range(nframes):
123-
self.assertAlmostEqual(vf[ff,ii], num_f[ff],
124-
places = places,
125-
msg = 'frame %d dof %d does not match' % (ff, ii))
122+
np.testing.assert_almost_equal(vf[:,ii].ravel(), num_f.ravel(),
123+
places,
124+
err_msg = 'dof %d does not match' % (ii))
126125

127126
box3 = np.reshape(box, [nframes, 3,3])
128127
rbox3 = np.linalg.inv(box3)
@@ -150,10 +149,7 @@ def _test_fv (self):
150149
t_esti = np.matmul(num_deriv, box3)
151150

152151
# print(t_esti, '\n', vv.reshape([-1, 3, 3]))
153-
for ff in range(nframes):
154-
for ii in range(3):
155-
for jj in range(3):
156-
self.assertAlmostEqual(t_esti[ff][ii][jj], vv[ff,ii*3+jj],
157-
places = places,
158-
msg = "frame %d virial component [%d,%d] failed" % (ff, ii, jj))
152+
np.testing.assert_almost_equal(t_esti.ravel(), vv.ravel(),
153+
places,
154+
err_msg = "virial component failed")
159155

source/tests/test_data_modifier_shuffle.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,9 @@ def test_z_dipole(self):
185185
dv01 = dv01.reshape([self.nframes, -1])
186186
dv1 = dv1.reshape([self.nframes, -1])
187187

188-
for ii in range(self.nframes):
189-
for jj in range(self.nsel):
190-
self.assertAlmostEqual(
191-
dv01[ii][jj], dv1[ii][jj],
192-
msg = "dipole [%d,%d] dose not match" % (ii, jj))
188+
np.testing.assert_almost_equal(
189+
dv01, dv1,
190+
err_msg = "dipole dose not match")
193191

194192

195193
def test_modify(self):
@@ -202,18 +200,12 @@ def test_modify(self):
202200
ve1, vf1, vv1 = dcm.eval(self.coords1, self.box1, self.atom_types1)
203201
vf01 = vf0[:,self.idx_map, :]
204202

205-
for ii in range(self.nframes):
206-
self.assertAlmostEqual(ve0[ii], ve1[ii],
207-
msg = 'energy %d should match' % ii)
208-
for ii in range(self.nframes):
209-
for jj in range(9):
210-
self.assertAlmostEqual(vv0[ii][jj], vv1[ii][jj],
211-
msg = 'virial [%d,%d] should match' % (ii,jj))
212-
for ii in range(self.nframes):
213-
for jj in range(self.natoms):
214-
for dd in range(3):
215-
self.assertAlmostEqual(
216-
vf01[ii][jj][dd], vf1[ii][jj][dd],
217-
msg = "force [%d,%d,%d] dose not match" % (ii,jj,dd))
203+
np.testing.assert_almost_equal(ve0, ve1,
204+
err_msg = 'energy should match')
205+
np.testing.assert_almost_equal(vv0, vv1,
206+
err_msg = 'virial should match')
207+
np.testing.assert_almost_equal(
208+
vf01, vf1,
209+
err_msg = "force dose not match")
218210

219211

source/tests/test_deepdipole.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def test_1frame_atm(self):
4444
nsel = 2
4545
self.assertEqual(dd.shape, (nframes,nsel,3))
4646
# check values
47-
for ii in range(dd.size):
48-
self.assertAlmostEqual(dd.reshape([-1])[ii], self.expected_d.reshape([-1])[ii], places = default_places)
47+
np.testing.assert_almost_equal(dd.ravel(), self.expected_d, default_places)
4948

5049
def test_2frame_atm(self):
5150
coords2 = np.concatenate((self.coords, self.coords))
@@ -58,8 +57,7 @@ def test_2frame_atm(self):
5857
self.assertEqual(dd.shape, (nframes,nsel,3))
5958
# check values
6059
expected_d = np.concatenate((self.expected_d, self.expected_d))
61-
for ii in range(dd.size):
62-
self.assertAlmostEqual(dd.reshape([-1])[ii], expected_d.reshape([-1])[ii], places = default_places)
60+
np.testing.assert_almost_equal(dd.ravel(), expected_d, default_places)
6361

6462

6563
class TestDeepDipoleNoPBC(unittest.TestCase) :
@@ -87,8 +85,7 @@ def test_1frame_atm(self):
8785
nsel = 2
8886
self.assertEqual(dd.shape, (nframes,nsel,3))
8987
# check values
90-
for ii in range(dd.size):
91-
self.assertAlmostEqual(dd.reshape([-1])[ii], self.expected_d.reshape([-1])[ii], places = default_places)
88+
np.testing.assert_almost_equal(dd.ravel(), self.expected_d, default_places)
9289

9390
def test_1frame_atm_large_box(self):
9491
dd = self.dp.eval(self.coords, self.box, self.atype)
@@ -98,8 +95,7 @@ def test_1frame_atm_large_box(self):
9895
nsel = 2
9996
self.assertEqual(dd.shape, (nframes,nsel,3))
10097
# check values
101-
for ii in range(dd.size):
102-
self.assertAlmostEqual(dd.reshape([-1])[ii], self.expected_d.reshape([-1])[ii], places = default_places)
98+
np.testing.assert_almost_equal(dd.ravel(), self.expected_d, default_places)
10399

104100

105101
@unittest.skipIf(parse_version(tf.__version__) < parse_version("1.15"),
@@ -138,8 +134,7 @@ def test_1frame_old(self):
138134
nframes = 1
139135
self.assertEqual(gt.shape, (nframes,self.nout))
140136
# check values
141-
for ii in range(gt.size):
142-
self.assertAlmostEqual(gt.reshape([-1])[ii], self.expected_gt.reshape([-1])[ii], places = default_places)
137+
np.testing.assert_almost_equal(gt.ravel(), self.expected_gt, default_places)
143138

144139
def test_1frame_old_atm(self):
145140
at = self.dp.eval(self.coords, self.box, self.atype)
@@ -149,8 +144,7 @@ def test_1frame_old_atm(self):
149144
nsel = 2
150145
self.assertEqual(at.shape, (nframes,nsel,self.nout))
151146
# check values
152-
for ii in range(at.size):
153-
self.assertAlmostEqual(at.reshape([-1])[ii], self.expected_t.reshape([-1])[ii], places = default_places)
147+
np.testing.assert_almost_equal(at.ravel(), self.expected_t, default_places)
154148

155149
def test_2frame_old_atm(self):
156150
coords2 = np.concatenate((self.coords, self.coords))
@@ -163,8 +157,7 @@ def test_2frame_old_atm(self):
163157
self.assertEqual(at.shape, (nframes,nsel,self.nout))
164158
# check values
165159
expected_d = np.concatenate((self.expected_t, self.expected_t))
166-
for ii in range(at.size):
167-
self.assertAlmostEqual(at.reshape([-1])[ii], expected_d.reshape([-1])[ii], places = default_places)
160+
np.testing.assert_almost_equal(at.ravel(), expected_d, default_places)
168161

169162
def test_1frame_full(self):
170163
gt, ff, vv = self.dp.eval_full(self.coords, self.box, self.atype, atomic = False)
@@ -175,12 +168,9 @@ def test_1frame_full(self):
175168
self.assertEqual(ff.shape, (nframes,self.nout,natoms,3))
176169
self.assertEqual(vv.shape, (nframes,self.nout,9))
177170
# check values
178-
for ii in range(ff.size):
179-
self.assertAlmostEqual(ff.reshape([-1])[ii], self.expected_f.reshape([-1])[ii], places = default_places)
180-
for ii in range(gt.size):
181-
self.assertAlmostEqual(gt.reshape([-1])[ii], self.expected_gt.reshape([-1])[ii], places = default_places)
182-
for ii in range(vv.size):
183-
self.assertAlmostEqual(vv.reshape([-1])[ii], self.expected_gv.reshape([-1])[ii], places = default_places)
171+
np.testing.assert_almost_equal(ff.ravel(), self.expected_f, default_places)
172+
np.testing.assert_almost_equal(gt.ravel(), self.expected_gt, default_places)
173+
np.testing.assert_almost_equal(vv.ravel(), self.expected_gv, default_places)
184174

185175
def test_1frame_full_atm(self):
186176
gt, ff, vv, at, av = self.dp.eval_full(self.coords, self.box, self.atype, atomic = True)

source/tests/test_deepmd_data.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,15 @@ def test_load_set_1(self) :
4949
.add('value_1', 1, atomic=True, must=True, type_sel = [0])
5050
data = dd._load_set(os.path.join(self.data_name, 'set.foo'))
5151
self.assertEqual(data['value_1'].shape, (self.nframes, 2))
52-
for ii in range(self.nframes):
53-
for jj in range(2):
54-
self.assertAlmostEqual(data['value_1'][ii][jj],
55-
self.value_1[ii][jj])
52+
np.testing.assert_almost_equal(data['value_1'], self.value_1)
5653

5754

5855
def test_load_set_2(self) :
5956
dd = DeepmdData(self.data_name)\
6057
.add('value_2', 1, atomic=True, must=True, type_sel = [1])
6158
data = dd._load_set(os.path.join(self.data_name, 'set.foo'))
6259
self.assertEqual(data['value_2'].shape, (self.nframes, 4))
63-
for ii in range(self.nframes):
64-
for jj in range(4):
65-
self.assertAlmostEqual(data['value_2'][ii][jj],
66-
self.value_2[ii][jj])
60+
np.testing.assert_almost_equal(data['value_2'], self.value_2)
6761

6862

6963
class TestData (unittest.TestCase) :
@@ -217,8 +211,7 @@ def test_avg(self) :
217211
.add('test_frame', 5, atomic=False, must=True)
218212
favg = dd.avg('test_frame')
219213
fcmp = np.average(np.concatenate((self.test_frame, self.test_frame_bar), axis = 0), axis = 0)
220-
for ii in range(favg.size) :
221-
self.assertAlmostEqual((favg[ii]), (fcmp[ii]), places = places)
214+
np.testing.assert_almost_equal(favg, fcmp, places)
222215

223216
def test_check_batch_size(self) :
224217
dd = DeepmdData(self.data_name)
@@ -263,8 +256,4 @@ def test_get_nbatch(self):
263256
self.assertEqual(nb, 2)
264257

265258
def _comp_np_mat2(self, first, second) :
266-
for ii in range(first.shape[0]) :
267-
for jj in range(first.shape[1]) :
268-
self.assertAlmostEqual(first[ii][jj], second[ii][jj],
269-
msg = 'item [%d][%d] does not match' % (ii,jj),
270-
places = places)
259+
np.testing.assert_almost_equal(first, second, places)

source/tests/test_deeppolar.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def test_1frame_atm(self):
4444
nsel = 2
4545
self.assertEqual(dd.shape, (nframes,nsel,9))
4646
# check values
47-
for ii in range(dd.size):
48-
self.assertAlmostEqual(dd.reshape([-1])[ii], self.expected_d.reshape([-1])[ii], places = default_places)
47+
np.testing.assert_almost_equal(dd.ravel(), self.expected_d, default_places)
4948

5049
def test_2frame_atm(self):
5150
coords2 = np.concatenate((self.coords, self.coords))
@@ -58,8 +57,8 @@ def test_2frame_atm(self):
5857
self.assertEqual(dd.shape, (nframes,nsel,9))
5958
# check values
6059
expected_d = np.concatenate((self.expected_d, self.expected_d))
61-
for ii in range(dd.size):
62-
self.assertAlmostEqual(dd.reshape([-1])[ii], expected_d.reshape([-1])[ii], places = default_places)
60+
np.testing.assert_almost_equal(dd.ravel(), expected_d, default_places)
61+
6362

6463

6564
class TestDeepPolarNoPBC(unittest.TestCase) :
@@ -87,8 +86,7 @@ def test_1frame_atm(self):
8786
nsel = 2
8887
self.assertEqual(dd.shape, (nframes,nsel,9))
8988
# check values
90-
for ii in range(dd.size):
91-
self.assertAlmostEqual(dd.reshape([-1])[ii], self.expected_d.reshape([-1])[ii], places = default_places)
89+
np.testing.assert_almost_equal(dd.ravel(), self.expected_d, default_places)
9290

9391
def test_1frame_atm_large_box(self):
9492
dd = self.dp.eval(self.coords, self.box, self.atype)
@@ -98,8 +96,7 @@ def test_1frame_atm_large_box(self):
9896
nsel = 2
9997
self.assertEqual(dd.shape, (nframes,nsel,9))
10098
# check values
101-
for ii in range(dd.size):
102-
self.assertAlmostEqual(dd.reshape([-1])[ii], self.expected_d.reshape([-1])[ii], places = default_places)
99+
np.testing.assert_almost_equal(dd.ravel(), self.expected_d, default_places)
103100

104101

105102
@unittest.skipIf(parse_version(tf.__version__) < parse_version("1.15"),
@@ -138,8 +135,7 @@ def test_1frame_old(self):
138135
nframes = 1
139136
self.assertEqual(gt.shape, (nframes,self.nout))
140137
# check values
141-
for ii in range(gt.size):
142-
self.assertAlmostEqual(gt.reshape([-1])[ii], self.expected_gt.reshape([-1])[ii], places = default_places)
138+
np.testing.assert_almost_equal(gt.ravel(), self.expected_gt, default_places)
143139

144140
def test_1frame_old_atm(self):
145141
at = self.dp.eval(self.coords, self.box, self.atype)
@@ -149,8 +145,7 @@ def test_1frame_old_atm(self):
149145
nsel = 2
150146
self.assertEqual(at.shape, (nframes,nsel,self.nout))
151147
# check values
152-
for ii in range(at.size):
153-
self.assertAlmostEqual(at.reshape([-1])[ii], self.expected_t.reshape([-1])[ii], places = default_places)
148+
np.testing.assert_almost_equal(at.ravel(), self.expected_t, default_places)
154149

155150
def test_2frame_old_atm(self):
156151
coords2 = np.concatenate((self.coords, self.coords))
@@ -163,8 +158,7 @@ def test_2frame_old_atm(self):
163158
self.assertEqual(at.shape, (nframes,nsel,self.nout))
164159
# check values
165160
expected_d = np.concatenate((self.expected_t, self.expected_t))
166-
for ii in range(at.size):
167-
self.assertAlmostEqual(at.reshape([-1])[ii], expected_d.reshape([-1])[ii], places = default_places)
161+
np.testing.assert_almost_equal(at.ravel(), expected_d, default_places)
168162

169163
def test_1frame_full(self):
170164
gt, ff, vv = self.dp.eval_full(self.coords, self.box, self.atype, atomic = False)
@@ -175,12 +169,9 @@ def test_1frame_full(self):
175169
self.assertEqual(ff.shape, (nframes,self.nout,natoms,3))
176170
self.assertEqual(vv.shape, (nframes,self.nout,9))
177171
# check values
178-
for ii in range(ff.size):
179-
self.assertAlmostEqual(ff.reshape([-1])[ii], self.expected_f.reshape([-1])[ii], places = default_places)
180-
for ii in range(gt.size):
181-
self.assertAlmostEqual(gt.reshape([-1])[ii], self.expected_gt.reshape([-1])[ii], places = default_places)
182-
for ii in range(vv.size):
183-
self.assertAlmostEqual(vv.reshape([-1])[ii], self.expected_gv.reshape([-1])[ii], places = default_places)
172+
np.testing.assert_almost_equal(ff.ravel(), self.expected_f, default_places)
173+
np.testing.assert_almost_equal(gt.ravel(), self.expected_gt, default_places)
174+
np.testing.assert_almost_equal(vv.ravel(), self.expected_gv, default_places)
184175

185176
def test_1frame_full_atm(self):
186177
gt, ff, vv, at, av = self.dp.eval_full(self.coords, self.box, self.atype, atomic = True)

0 commit comments

Comments
 (0)