Skip to content

Commit 627e5bd

Browse files
authored
Adjust the relative error of QR's grad (#44785)
* Adjust the relative error of QR's grad (#42221) * Fix the format
1 parent cd59df5 commit 627e5bd

File tree

2 files changed

+28
-17
lines changed

2 files changed

+28
-17
lines changed

python/paddle/fluid/tests/unittests/test_qr_op.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525

2626

2727
class TestQrOp(OpTest):
28+
2829
def setUp(self):
2930
paddle.enable_static()
30-
np.random.seed(4)
31+
np.random.seed(7)
3132
self.op_type = "qr"
3233
a, q, r = self.get_input_and_output()
3334
self.inputs = {"X": a}
@@ -74,30 +75,37 @@ def test_check_output(self):
7475
self.check_output()
7576

7677
def test_check_grad_normal(self):
77-
self.check_grad(['X'], ['Q', 'R'])
78+
self.check_grad(['X'], ['Q', 'R'],
79+
numeric_grad_delta=1e-5,
80+
max_relative_error=1e-6)
7881

7982

8083
class TestQrOpCase1(TestQrOp):
84+
8185
def get_shape(self):
8286
return (10, 12)
8387

8488

8589
class TestQrOpCase2(TestQrOp):
90+
8691
def get_shape(self):
8792
return (16, 15)
8893

8994

9095
class TestQrOpCase3(TestQrOp):
96+
9197
def get_shape(self):
9298
return (2, 12, 16)
9399

94100

95101
class TestQrOpCase4(TestQrOp):
102+
96103
def get_shape(self):
97104
return (3, 16, 15)
98105

99106

100107
class TestQrOpCase5(TestQrOp):
108+
101109
def get_mode(self):
102110
return "complete"
103111

@@ -106,6 +114,7 @@ def get_shape(self):
106114

107115

108116
class TestQrOpCase6(TestQrOp):
117+
109118
def get_mode(self):
110119
return "complete"
111120

@@ -114,8 +123,10 @@ def get_shape(self):
114123

115124

116125
class TestQrAPI(unittest.TestCase):
126+
117127
def test_dygraph(self):
118128
paddle.disable_static()
129+
np.random.seed(7)
119130

120131
def run_qr_dygraph(shape, mode, dtype):
121132
if dtype == "float32":
@@ -174,12 +185,13 @@ def run_qr_dygraph(shape, mode, dtype):
174185
]
175186
modes = ["reduced", "complete", "r"]
176187
dtypes = ["float32", "float64"]
177-
for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes,
178-
dtypes):
188+
for tensor_shape, mode, dtype in itertools.product(
189+
tensor_shapes, modes, dtypes):
179190
run_qr_dygraph(tensor_shape, mode, dtype)
180191

181192
def test_static(self):
182193
paddle.enable_static()
194+
np.random.seed(7)
183195

184196
def run_qr_static(shape, mode, dtype):
185197
if dtype == "float32":
@@ -216,29 +228,27 @@ def run_qr_static(shape, mode, dtype):
216228
tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode)
217229
np_q[coord] = tmp_q
218230
np_r[coord] = tmp_r
219-
x = paddle.fluid.data(
220-
name="input", shape=shape, dtype=dtype)
231+
x = paddle.fluid.data(name="input",
232+
shape=shape,
233+
dtype=dtype)
221234
if mode == "r":
222235
r = paddle.linalg.qr(x, mode=mode)
223236
exe = fluid.Executor(place)
224237
fetches = exe.run(fluid.default_main_program(),
225238
feed={"input": a},
226239
fetch_list=[r])
227-
self.assertTrue(
228-
np.allclose(
229-
fetches[0], np_r, atol=1e-5))
240+
self.assertTrue(np.allclose(fetches[0], np_r,
241+
atol=1e-5))
230242
else:
231243
q, r = paddle.linalg.qr(x, mode=mode)
232244
exe = fluid.Executor(place)
233245
fetches = exe.run(fluid.default_main_program(),
234246
feed={"input": a},
235247
fetch_list=[q, r])
236-
self.assertTrue(
237-
np.allclose(
238-
fetches[0], np_q, atol=1e-5))
239-
self.assertTrue(
240-
np.allclose(
241-
fetches[1], np_r, atol=1e-5))
248+
self.assertTrue(np.allclose(fetches[0], np_q,
249+
atol=1e-5))
250+
self.assertTrue(np.allclose(fetches[1], np_r,
251+
atol=1e-5))
242252

243253
tensor_shapes = [
244254
(3, 5),
@@ -253,8 +263,8 @@ def run_qr_static(shape, mode, dtype):
253263
]
254264
modes = ["reduced", "complete", "r"]
255265
dtypes = ["float32", "float64"]
256-
for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes,
257-
dtypes):
266+
for tensor_shape, mode, dtype in itertools.product(
267+
tensor_shapes, modes, dtypes):
258268
run_qr_static(tensor_shape, mode, dtype)
259269

260270

python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
'matrix_power', \
5252
'cholesky_solve', \
5353
'solve', \
54+
'qr', \
5455
]
5556

5657
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\

0 commit comments

Comments
 (0)