Skip to content

Commit 8464c04

Browse files
authored
【Paddle Tensor No.26】fix svdvals (#69820)
* fix svdvals * fix * fix * fix * fix bug in op_gen * fix svdvals and restore op_gen
1 parent 8c5e854 commit 8464c04

File tree

4 files changed

+17
-27
lines changed

4 files changed

+17
-27
lines changed

paddle/phi/kernels/cpu/svdvals_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ void SvdvalsKernel(const Context& dev_ctx,
107107
0,
108108
phi::errors::InvalidArgument("The batch size of Input(X) must be > 0."));
109109
DDim s_dims;
110-
if (batches == 1) {
110+
if (x_dims.size() <= 2) {
111111
s_dims = {k};
112112
} else {
113113
s_dims = {batches, k};

python/paddle/tensor/linalg.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3025,20 +3025,7 @@ def svdvals(x: Tensor, name: str | None = None) -> Tensor:
30253025
Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
30263026
[8.14753819, 0.78589684])
30273027
"""
3028-
if in_dynamic_or_pir_mode():
3029-
return _C_ops.svdvals(x)
3030-
else:
3031-
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'svdvals')
3032-
helper = LayerHelper('svdvals', **locals())
3033-
s = helper.create_variable_for_type_inference(dtype=x.dtype)
3034-
attrs = {}
3035-
helper.append_op(
3036-
type='svdvals',
3037-
inputs={'X': [x]},
3038-
outputs={'S': s},
3039-
attrs=attrs,
3040-
)
3041-
return s
3028+
return _C_ops.svdvals(x)
30423029

30433030

30443031
def _conjugate(x):

test/legacy_test/test_svdvals_op.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
1515
import unittest
1616

1717
import numpy as np
18-
from op_test import OpTest, skip_check_grad_ci
18+
from op_test import OpTest
1919
from utils import dygraph_guard, static_guard
2020

2121
import paddle
@@ -69,13 +69,10 @@ def init_data(self):
6969
self.outputs = {"s": self._output_data}
7070

7171

72-
@skip_check_grad_ci(
73-
reason="'check_grad' on singular values is not required for svdvals."
74-
)
7572
class TestSvdvalsBigMatrix(TestSvdvalsOp):
7673
def init_data(self):
7774
"""Generate large input matrix."""
78-
self._input_shape = (200, 300)
75+
self._input_shape = (40, 40)
7976
self._input_data = np.random.random(self._input_shape).astype("float64")
8077
self._output_data = np.linalg.svd(
8178
self._input_data, compute_uv=False, hermitian=False
@@ -84,7 +81,13 @@ def init_data(self):
8481
self.outputs = {'s': self._output_data}
8582

8683
def test_check_grad(self):
87-
pass
84+
self.check_grad(
85+
['x'],
86+
['s'],
87+
numeric_grad_delta=0.001,
88+
max_relative_error=1e-5,
89+
check_pir=True,
90+
)
8891

8992

9093
class TestSvdvalsAPI(unittest.TestCase):
@@ -103,8 +106,7 @@ def test_dygraph_api(self):
103106
# Test dynamic graph for svdvals
104107
s = paddle.linalg.svdvals(x)
105108
np_s = np.linalg.svd(self.x_np, compute_uv=False, hermitian=False)
106-
self.assertTrue(np.allclose(np_s, s.numpy(), rtol=1e-6))
107-
109+
np.testing.assert_allclose(np_s, s.numpy(), rtol=1e-6)
108110
# Test with reshaped input
109111
x_reshaped = x.reshape([-1, 12, 10])
110112
s_reshaped = paddle.linalg.svdvals(x_reshaped)
@@ -114,8 +116,8 @@ def test_dygraph_api(self):
114116
for matrix in self.x_np.reshape([-1, 12, 10])
115117
]
116118
)
117-
self.assertTrue(
118-
np.allclose(np_s_reshaped, s_reshaped.numpy(), rtol=1e-6)
119+
np.testing.assert_allclose(
120+
np_s_reshaped, s_reshaped.numpy(), rtol=1e-6
119121
)
120122

121123
def test_static_api(self):
@@ -130,7 +132,7 @@ def test_static_api(self):
130132

131133
np_s = np.linalg.svd(self.x_np, compute_uv=False, hermitian=False)
132134
for r in res:
133-
self.assertTrue(np.allclose(np_s, r, rtol=1e-6))
135+
np.testing.assert_allclose(np_s, r, rtol=1e-6)
134136

135137
def test_error(self):
136138
"""Test invalid inputs for svdvals"""

test/white_list/op_threshold_white_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
'lgamma',
4949
'sparse_attention',
5050
'svd',
51+
'svdvals',
5152
'matrix_power',
5253
'cholesky_solve',
5354
'solve',

0 commit comments

Comments
 (0)