Skip to content

Commit 850b737

Browse files
authored
Fix nparray.all() bug. (#16472)
1 parent 1b4e4e7 commit 850b737

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

python/paddle/fluid/tests/unittests/test_dist_save_load.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ def check_with_place(self,
6868
train0_np = np.array(tr0_var)
6969
train1_np = np.array(tr1_var)
7070

71-
self.assertAlmostEqual(local_np.all(), train0_np.all(), delta=delta)
72-
self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta)
73-
self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta)
71+
np.testing.assert_almost_equal(local_np, train0_np, decimal=2)
72+
np.testing.assert_almost_equal(local_np, train1_np, decimal=2)
73+
np.testing.assert_almost_equal(train0_np, train1_np, decimal=2)
7474

7575
def test_dist(self):
7676
need_envs = {
@@ -134,10 +134,8 @@ def check_with_place(self,
134134
train0_2_np = np.array(tr0_var_2)
135135
train1_2_np = np.array(tr1_var_2)
136136

137-
self.assertAlmostEqual(
138-
train0_1_np.all(), train0_2_np.all(), delta=delta)
139-
self.assertAlmostEqual(
140-
train1_1_np.all(), train1_2_np.all(), delta=delta)
137+
np.testing.assert_almost_equal(train0_1_np, train0_2_np, decimal=2)
138+
np.testing.assert_almost_equal(train1_1_np, train1_2_np, decimal=2)
141139

142140
def test_dist(self):
143141
need_envs = {

python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,9 @@ def _run_nce_op_two_pserver(self, place, port0, port1):
205205
out = nce(x_array, param_array, bias_array, sample_weight,
206206
label_array, 5, 2)
207207

208-
self.assertAlmostEqual(o_cost.all(), out[0].all(), delta=1e-6)
209-
self.assertAlmostEqual(o_logits.all(), out[1].all(), delta=1e-6)
210-
self.assertAlmostEqual(o_labels.all(), out[2].all(), delta=1e-6)
208+
np.testing.assert_almost_equal(o_cost, out[0], decimal=6)
209+
np.testing.assert_almost_equal(o_logits, out[1], decimal=6)
210+
np.testing.assert_almost_equal(o_labels, out[2], decimal=6)
211211

212212
def test_nce_op_remote(self):
213213
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"

0 commit comments

Comments
 (0)