Skip to content

Commit 2c253d5

Browse files
MAINT: Replace deprecated numpy assertions in coordinates/base.py (#5156)
* MAINT: Replace old numpy assertions --------- Co-authored-by: Egor Marin <me@marinegor.dev>
1 parent 21f7d80 commit 2c253d5

File tree

2 files changed

+86
-49
lines changed

2 files changed

+86
-49
lines changed

package/AUTHORS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ Chronological list of authors
270270
2026
271271
- Mohammad Ayaan
272272
- Khushi Phougat
273-
273+
- Kushagar Garg
274+
274275
External code
275276
-------------
276277

testsuite/MDAnalysisTests/coordinates/base.py

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
from unittest import TestCase
2929
from numpy.testing import (
3030
assert_equal,
31-
assert_almost_equal,
32-
assert_array_almost_equal,
3331
assert_allclose,
3432
)
3533

@@ -118,21 +116,23 @@ def test_dt(self):
118116
def test_coordinates(self):
119117
A10CA = self.universe.select_atoms("name CA")[10]
120118
# restrict accuracy to maximum in PDB files (3 decimals)
121-
assert_almost_equal(
119+
assert_allclose(
122120
A10CA.position,
123121
self.ref_coordinates["A10CA"],
124-
3,
122+
atol=10 ** (-self.prec),
123+
rtol=0,
125124
err_msg="wrong coordinates for A10:CA",
126125
)
127126

128127
def test_distances(self):
129128
NTERM = self.universe.select_atoms("name N")[0]
130129
CTERM = self.universe.select_atoms("name C")[-1]
131130
d = mda.lib.mdamath.norm(NTERM.position - CTERM.position)
132-
assert_almost_equal(
131+
assert_allclose(
133132
d,
134133
self.ref_distances["endtoend"],
135-
self.prec,
134+
atol=10 ** (-self.prec),
135+
rtol=0,
136136
err_msg="distance between M1:N and G214:C",
137137
)
138138

@@ -322,21 +322,26 @@ def test_get_writer_2(self, ref, reader, tmpdir):
322322
assert_equal(W.n_atoms, 100)
323323

324324
def test_dt(self, ref, reader):
325-
assert_almost_equal(reader.dt, ref.dt, decimal=ref.prec)
325+
assert_allclose(reader.dt, ref.dt, atol=10 ** (-ref.prec), rtol=0)
326326

327327
def test_ts_dt_matches_reader(self, reader):
328328
assert_equal(reader.ts.dt, reader.dt)
329329

330330
def test_total_time(self, ref, reader):
331-
assert_almost_equal(reader.totaltime, ref.totaltime, decimal=ref.prec)
331+
assert_allclose(
332+
reader.totaltime, ref.totaltime, atol=10 ** (-ref.prec), rtol=0
333+
)
332334

333335
def test_first_dimensions(self, ref, reader):
334336
reader.rewind()
335337
if ref.dimensions is None:
336338
assert reader.ts.dimensions is None
337339
else:
338-
assert_array_almost_equal(
339-
reader.ts.dimensions, ref.dimensions, decimal=ref.prec
340+
assert_allclose(
341+
reader.ts.dimensions,
342+
ref.dimensions,
343+
atol=10 ** (-ref.prec),
344+
rtol=0,
340345
)
341346

342347
def test_changing_dimensions(self, ref, reader):
@@ -345,25 +350,29 @@ def test_changing_dimensions(self, ref, reader):
345350
if ref.dimensions is None:
346351
assert reader.ts.dimensions is None
347352
else:
348-
assert_array_almost_equal(
349-
reader.ts.dimensions, ref.dimensions, decimal=ref.prec
353+
assert_allclose(
354+
reader.ts.dimensions,
355+
ref.dimensions,
356+
atol=10 ** (-ref.prec),
357+
rtol=0,
350358
)
351359
reader[1]
352360
if ref.dimensions_second_frame is None:
353361
assert reader.ts.dimensions is None
354362
else:
355-
assert_array_almost_equal(
363+
assert_allclose(
356364
reader.ts.dimensions,
357365
ref.dimensions_second_frame,
358-
decimal=ref.prec,
366+
atol=10 ** (-ref.prec),
367+
rtol=0,
359368
)
360369

361370
def test_volume(self, ref, reader):
362371
reader.rewind()
363372
vol = reader.ts.volume
364373
# Here we can only be sure about the numbers upto the decimal point due
365374
# to floating point impressions.
366-
assert_almost_equal(vol, ref.volume, 0)
375+
assert_allclose(vol, ref.volume, atol=1, rtol=0)
367376

368377
def test_iter(self, ref, reader):
369378
for i, ts in enumerate(reader):
@@ -387,9 +396,11 @@ def test_remove_nonexistant_auxiliary_raises_ValueError(self, reader):
387396
def test_iter_auxiliary(self, ref, reader):
388397
# should go through all steps in 'highf'
389398
for i, auxstep in enumerate(reader.iter_auxiliary("highf")):
390-
assert_almost_equal(
399+
assert_allclose(
391400
auxstep.data,
392401
ref.aux_highf_all_data[i],
402+
atol=1e-7,
403+
rtol=0,
393404
err_msg="Auxiliary data does not match for "
394405
"step {}".format(i),
395406
)
@@ -453,9 +464,7 @@ def test_transformations_iter(self, ref, transformed):
453464
v2 = np.float32((0, 0, 0.33))
454465
for i, ts in enumerate(transformed):
455466
idealcoords = ref.iter_ts(i).positions + v1 + v2
456-
assert_array_almost_equal(
457-
ts.positions, idealcoords, decimal=ref.prec
458-
)
467+
assert_allclose(ts.positions, idealcoords, atol=ref.prec, rtol=0)
459468

460469
def test_transformations_2iter(self, ref, transformed):
461470
# Are the transformations applied and
@@ -465,21 +474,23 @@ def test_transformations_2iter(self, ref, transformed):
465474
idealcoords = []
466475
for i, ts in enumerate(transformed):
467476
idealcoords.append(ref.iter_ts(i).positions + v1 + v2)
468-
assert_array_almost_equal(
469-
ts.positions, idealcoords[i], decimal=ref.prec
477+
assert_allclose(
478+
ts.positions, idealcoords[i], atol=10 ** (-ref.prec), rtol=0
470479
)
471480

472481
for i, ts in enumerate(transformed):
473-
assert_almost_equal(ts.positions, idealcoords[i], decimal=ref.prec)
482+
assert_allclose(
483+
ts.positions, idealcoords[i], atol=10 ** (-ref.prec), rtol=0
484+
)
474485

475486
def test_transformations_slice(self, ref, transformed):
476487
# Are the transformations applied when iterating over a slice of the trajectory?
477488
v1 = np.float32((1, 1, 1))
478489
v2 = np.float32((0, 0, 0.33))
479490
for i, ts in enumerate(transformed[2:3:1]):
480491
idealcoords = ref.iter_ts(ts.frame).positions + v1 + v2
481-
assert_array_almost_equal(
482-
ts.positions, idealcoords, decimal=ref.prec
492+
assert_allclose(
493+
ts.positions, idealcoords, atol=10 ** (-ref.prec), rtol=0
483494
)
484495

485496
def test_transformations_switch_frame(self, ref, transformed):
@@ -490,26 +501,41 @@ def test_transformations_switch_frame(self, ref, transformed):
490501
v2 = np.float32((0, 0, 0.33))
491502
first_ideal = ref.iter_ts(0).positions + v1 + v2
492503
if len(transformed) > 1:
493-
assert_array_almost_equal(
494-
transformed[0].positions, first_ideal, decimal=ref.prec
504+
assert_allclose(
505+
transformed[0].positions,
506+
first_ideal,
507+
atol=10 ** (-ref.prec),
508+
rtol=0,
495509
)
496510
second_ideal = ref.iter_ts(1).positions + v1 + v2
497-
assert_array_almost_equal(
498-
transformed[1].positions, second_ideal, decimal=ref.prec
511+
assert_allclose(
512+
transformed[1].positions,
513+
second_ideal,
514+
atol=10 ** (-ref.prec),
515+
rtol=0,
499516
)
500517

501518
# What if we comeback to the previous frame?
502-
assert_array_almost_equal(
503-
transformed[0].positions, first_ideal, decimal=ref.prec
519+
assert_allclose(
520+
transformed[0].positions,
521+
first_ideal,
522+
atol=10 ** (-ref.prec),
523+
rtol=0,
504524
)
505525

506526
# How about we switch the frame to itself?
507-
assert_array_almost_equal(
508-
transformed[0].positions, first_ideal, decimal=ref.prec
527+
assert_allclose(
528+
transformed[0].positions,
529+
first_ideal,
530+
atol=10 ** (-ref.prec),
531+
rtol=0,
509532
)
510533
else:
511-
assert_array_almost_equal(
512-
transformed[0].positions, first_ideal, decimal=ref.prec
534+
assert_allclose(
535+
transformed[0].positions,
536+
first_ideal,
537+
atol=10 ** (-ref.prec),
538+
rtol=0,
513539
)
514540

515541
def test_transformation_rewind(self, ref, transformed):
@@ -519,8 +545,11 @@ def test_transformation_rewind(self, ref, transformed):
519545
v2 = np.float32((0, 0, 0.33))
520546
ideal_coords = ref.iter_ts(0).positions + v1 + v2
521547
transformed.rewind()
522-
assert_array_almost_equal(
523-
transformed[0].positions, ideal_coords, decimal=ref.prec
548+
assert_allclose(
549+
transformed[0].positions,
550+
ideal_coords,
551+
atol=10 ** (-ref.prec),
552+
rtol=0,
524553
)
525554

526555
def test_transformations_copy(self, ref, transformed):
@@ -536,8 +565,8 @@ def test_transformations_copy(self, ref, transformed):
536565
)
537566
for i, ts in enumerate(new):
538567
ideal_coords = ref.iter_ts(i).positions + v1 + v2
539-
assert_array_almost_equal(
540-
ts.positions, ideal_coords, decimal=ref.prec
568+
assert_allclose(
569+
ts.positions, ideal_coords, atol=10 ** (-ref.prec), rtol=0
541570
)
542571

543572
def test_add_another_transformations_raises_ValueError(self, transformed):
@@ -812,8 +841,11 @@ def test_write_different_box(self, ref, universe, tmpdir):
812841

813842
for ts_ref, ts_w in zip(universe.trajectory, written):
814843
universe.dimensions[:3] += 1
815-
assert_array_almost_equal(
816-
universe.dimensions, ts_w.dimensions, decimal=ref.prec
844+
assert_allclose(
845+
universe.dimensions,
846+
ts_w.dimensions,
847+
atol=10 ** (-ref.prec),
848+
rtol=0,
817849
)
818850

819851
def test_write_trajectory_atomgroup(self, ref, reader, universe, tmpdir):
@@ -853,10 +885,11 @@ def test_write_selection(
853885

854886
copy = ref.reader(outfile)
855887
for orig_ts, copy_ts in zip(universe.trajectory, copy):
856-
assert_array_almost_equal(
888+
assert_allclose(
857889
copy_ts._pos,
858890
sel.atoms.positions,
859-
ref.prec,
891+
atol=10 ** (-ref.prec),
892+
rtol=0,
860893
err_msg="coordinate mismatch between original and written "
861894
"trajectory at frame {} (orig) vs {} (copy)".format(
862895
orig_ts.frame, copy_ts.frame
@@ -933,10 +966,11 @@ def assert_timestep_almost_equal(A, B, decimal=6, verbose=True):
933966
)
934967

935968
if A.has_positions:
936-
assert_array_almost_equal(
969+
assert_allclose(
937970
A.positions,
938971
B.positions,
939-
decimal=decimal,
972+
atol=10 ** (-decimal),
973+
rtol=0,
940974
err_msg="Timestep positions",
941975
verbose=verbose,
942976
)
@@ -949,10 +983,11 @@ def assert_timestep_almost_equal(A, B, decimal=6, verbose=True):
949983
)
950984
)
951985
if A.has_velocities:
952-
assert_array_almost_equal(
986+
assert_allclose(
953987
A.velocities,
954988
B.velocities,
955-
decimal=decimal,
989+
atol=10 ** (-decimal),
990+
rtol=0,
956991
err_msg="Timestep velocities",
957992
verbose=verbose,
958993
)
@@ -965,10 +1000,11 @@ def assert_timestep_almost_equal(A, B, decimal=6, verbose=True):
9651000
)
9661001
)
9671002
if A.has_forces:
968-
assert_array_almost_equal(
1003+
assert_allclose(
9691004
A.forces,
9701005
B.forces,
971-
decimal=decimal,
1006+
atol=10 ** (-decimal),
1007+
rtol=0,
9721008
err_msg="Timestep forces",
9731009
verbose=verbose,
9741010
)

0 commit comments

Comments
 (0)