Skip to content

Commit 8b818d0

Browse files
[Cherry-Pick] fix sync batch norm op under cuda12 (#54641)
* Fix bug of test_sync_batch_norm_op_static_build accuracy problem under cuda12. * Remove useless code modification. * Remove useless code modification.
1 parent 57d9b80 commit 8b818d0

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/legacy_test/test_sync_batch_norm_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def setUp(self):
110110
self.H = 32
111111
self.W = 32
112112
self.dshape = [self.N, self.C, self.H, self.W]
113-
self.atol = 1e-3
113+
self.atol = 5e-3
114114
self.data_dir = tempfile.TemporaryDirectory()
115115
self.fleet_log_dir = tempfile.TemporaryDirectory()
116116

@@ -296,7 +296,7 @@ def _compare_impl(self, place, layout, only_forward):
296296
np.testing.assert_allclose(
297297
convert_numpy_array(bn_val),
298298
convert_numpy_array(sync_bn_val),
299-
rtol=1e-05,
299+
rtol=1e-04,
300300
atol=self.atol,
301301
err_msg='Output ('
302302
+ fetch_names[i]
@@ -340,7 +340,7 @@ def setUp(self):
340340
self.H = 32
341341
self.W = 32
342342
self.dshape = [self.N, self.C, self.H, self.W]
343-
self.atol = 1e-3
343+
self.atol = 5e-3
344344
self.data_dir = tempfile.TemporaryDirectory()
345345
self.fleet_log_dir = tempfile.TemporaryDirectory()
346346

0 commit comments

Comments
 (0)