1111# limitations under the License.
1212
1313import numpy as np
14- from sklearn .utils .extmath import randomized_svd
15- from sklearn .utils import check_array
14+ from tensorly import partial_svd
1615
1716F32PREC = np .finfo (np .float32 ).eps
1817
@@ -73,7 +72,7 @@ def fill(
7372 inplace : bool
7473 Modify matrix or fill a copy
7574 """
76- X = check_array (X , force_all_finite = False )
75+ # X = check_array(X, force_all_finite=False)
7776
7877 if not inplace :
7978 X = X .copy ()
@@ -99,7 +98,7 @@ def prepare_input_data(self, X):
9998 Check to make sure that the input matrix and its mask of missing
10099 values are valid. Returns X and missing mask.
101100 """
102- X = check_array (X , force_all_finite = False )
101+ # X = check_array(X, force_all_finite=False)
103102 if X .dtype != "f" and X .dtype != "d" :
104103 X = X .astype (float )
105104
@@ -251,11 +250,7 @@ def _svd_step(self, X, shrinkage_value, max_rank=None):
251250 """
252251 if max_rank :
253252 # if we have a max rank then perform the faster randomized SVD
254- (U , s , V ) = randomized_svd (
255- X ,
256- max_rank ,
257- n_iter = self .n_power_iterations ,
258- random_state = None )
253+ U , s , V = partial_svd (X , max_rank )
259254 else :
260255 # perform a full rank SVD using ARPACK
261256 (U , s , V ) = np .linalg .svd (
@@ -273,15 +268,11 @@ def _svd_step(self, X, shrinkage_value, max_rank=None):
273268
274269 def _max_singular_value (self , X_filled ):
275270 # quick decomposition of X_filled into rank-1 SVD
276- _ , s , _ = randomized_svd (
277- X_filled ,
278- 1 ,
279- n_iter = 5 ,
280- random_state = None )
271+ _ , s , _ = partial_svd (X_filled , 1 )
281272 return s [0 ]
282273
283274 def solve (self , X , missing_mask ):
284- X = check_array (X , force_all_finite = False )
275+ # X = check_array(X, force_all_finite=False)
285276
286277 X_init = X .copy ()
287278
0 commit comments