@@ -96,7 +96,7 @@ def fista(
9696
9797
9898@nb .njit (fastmath = True )
99- def fista_cv (
99+ def fista_cv_nb (
100100 matrix : np .ndarray ,
101101 s : np .ndarray ,
102102 matrix_test : np .ndarray ,
@@ -112,9 +112,7 @@ def fista_cv(
112112 n_targets = s .shape [1 ]
113113 n_features = matrix .shape [1 ]
114114 prediction_error = np .zeros ((n_lambda , n_fold ))
115- iter_arr = np .zeros (n_lambda , dtype = int )
116- cv = np .zeros (n_lambda )
117- cvstd = np .zeros (n_lambda )
115+ iter_arr = np .zeros (n_lambda )
118116
119117 residue = np .zeros (max_iter )
120118 data_consistency = np .zeros (max_iter )
@@ -124,12 +122,12 @@ def fista_cv(
124122 y_train = s [..., fold ]
125123 x_test = matrix_test [..., fold ]
126124 y_test = s_test [..., fold ]
127- y_points = np . prod ( y_test .shape )
125+ y_points = y_test . shape [ 0 ] * y_test .shape [ 1 ]
128126
129127 gradient = x_train .T @ x_train
130128 c = x_train .T @ y_train
131129
132- norm_factor = np . linalg . norm (y_train ) ** 2
130+ norm_factor = norm (y_train ) ** 2
133131 f_k = np .zeros ((n_features , n_targets ))
134132 y_k = f_k .copy ()
135133
@@ -157,7 +155,7 @@ def fista_cv(
157155 else :
158156 f_k [:] = l1_soft_threshold (temp_c , l_inv * lam )
159157
160- residue [k ] = np . linalg . norm (x_train @ f_k - y_train ) ** 2
158+ residue [k ] = norm (x_train @ f_k - y_train ) ** 2
161159 fk_l1 = np .sum (np .abs (f_k ))
162160 data_consistency [k ] = residue [k ] + lam * fk_l1
163161
@@ -182,7 +180,24 @@ def fista_cv(
182180 prediction_error [j , fold ] = err / y_points
183181 iter_arr [j ] = k
184182
183+ return prediction_error , iter_arr
184+
185+
186+ def fista_cv (
187+ matrix : np .ndarray ,
188+ s : np .ndarray ,
189+ matrix_test : np .ndarray ,
190+ s_test : np .ndarray ,
191+ max_iter : int ,
192+ lambda_vals : np .ndarray ,
193+ nonnegative : bool ,
194+ l_inv : float ,
195+ tol : float ,
196+ ):
197+ prediction_error , iter_arr = fista_cv_nb (
198+ matrix , s , matrix_test , s_test , max_iter , lambda_vals , nonnegative , l_inv , tol
199+ )
200+
185201 cv = prediction_error .mean (axis = 1 )
186202 cvstd = prediction_error .std (axis = 1 )
187-
188203 return cv , cvstd , prediction_error , iter_arr
0 commit comments