1
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
15
15
import unittest
16
16
17
17
import numpy as np
18
- from op_test import OpTest , skip_check_grad_ci
18
+ from op_test import OpTest
19
19
from utils import dygraph_guard , static_guard
20
20
21
21
import paddle
@@ -69,13 +69,10 @@ def init_data(self):
69
69
self .outputs = {"s" : self ._output_data }
70
70
71
71
72
- @skip_check_grad_ci (
73
- reason = "'check_grad' on singular values is not required for svdvals."
74
- )
75
72
class TestSvdvalsBigMatrix (TestSvdvalsOp ):
76
73
def init_data (self ):
77
74
"""Generate large input matrix."""
78
- self ._input_shape = (200 , 300 )
75
+ self ._input_shape = (40 , 40 )
79
76
self ._input_data = np .random .random (self ._input_shape ).astype ("float64" )
80
77
self ._output_data = np .linalg .svd (
81
78
self ._input_data , compute_uv = False , hermitian = False
@@ -84,7 +81,13 @@ def init_data(self):
84
81
self .outputs = {'s' : self ._output_data }
85
82
86
83
def test_check_grad (self ):
87
- pass
84
+ self .check_grad (
85
+ ['x' ],
86
+ ['s' ],
87
+ numeric_grad_delta = 0.001 ,
88
+ max_relative_error = 1e-5 ,
89
+ check_pir = True ,
90
+ )
88
91
89
92
90
93
class TestSvdvalsAPI (unittest .TestCase ):
@@ -103,8 +106,7 @@ def test_dygraph_api(self):
103
106
# Test dynamic graph for svdvals
104
107
s = paddle .linalg .svdvals (x )
105
108
np_s = np .linalg .svd (self .x_np , compute_uv = False , hermitian = False )
106
- self .assertTrue (np .allclose (np_s , s .numpy (), rtol = 1e-6 ))
107
-
109
+ np .testing .assert_allclose (np_s , s .numpy (), rtol = 1e-6 )
108
110
# Test with reshaped input
109
111
x_reshaped = x .reshape ([- 1 , 12 , 10 ])
110
112
s_reshaped = paddle .linalg .svdvals (x_reshaped )
@@ -114,8 +116,8 @@ def test_dygraph_api(self):
114
116
for matrix in self .x_np .reshape ([- 1 , 12 , 10 ])
115
117
]
116
118
)
117
- self . assertTrue (
118
- np . allclose ( np_s_reshaped , s_reshaped .numpy (), rtol = 1e-6 )
119
+ np . testing . assert_allclose (
120
+ np_s_reshaped , s_reshaped .numpy (), rtol = 1e-6
119
121
)
120
122
121
123
def test_static_api (self ):
@@ -130,7 +132,7 @@ def test_static_api(self):
130
132
131
133
np_s = np .linalg .svd (self .x_np , compute_uv = False , hermitian = False )
132
134
for r in res :
133
- self . assertTrue ( np .allclose (np_s , r , rtol = 1e-6 ) )
135
+ np .testing . assert_allclose (np_s , r , rtol = 1e-6 )
134
136
135
137
def test_error (self ):
136
138
"""Test invalid inputs for svdvals"""
0 commit comments