Skip to content

Commit ea7c979

Browse files
committed
lint and fix numba fn
1 parent 94b711e commit ea7c979

File tree

4 files changed

+28
-700
lines changed

4 files changed

+28
-700
lines changed

mrinversion/linear_model/_base_l1l2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,18 +523,18 @@ def cross_validation_curve(self):
523523

524524
def cv(l1, X, y, cv):
525525
"""Return the cross-validation score as negative of mean square error."""
526-
if isinstance(l1, (Lasso, MultiTaskLasso)):
527-
fit_params = {"check_input": False}
528-
if isinstance(l1, LassoLars):
529-
fit_params = None # {"Xy": np.dot(X.T, y)}
526+
# if isinstance(l1, (Lasso, MultiTaskLasso)):
527+
# fit_params = {"check_input": False}
528+
# if isinstance(l1, LassoLars):
529+
# fit_params = None # {"Xy": np.dot(X.T, y)}
530530

531531
cv_score = cross_validate(
532532
l1,
533533
X=X,
534534
y=y,
535535
scoring="neg_mean_squared_error", # 'neg_mean_absolute_error",
536536
cv=cv,
537-
fit_params=fit_params,
537+
# fit_params=fit_params,
538538
n_jobs=1,
539539
verbose=0,
540540
)

mrinversion/linear_model/fista/__init__.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def fista(
9696

9797

9898
@nb.njit(fastmath=True)
99-
def fista_cv(
99+
def fista_cv_nb(
100100
matrix: np.ndarray,
101101
s: np.ndarray,
102102
matrix_test: np.ndarray,
@@ -112,9 +112,7 @@ def fista_cv(
112112
n_targets = s.shape[1]
113113
n_features = matrix.shape[1]
114114
prediction_error = np.zeros((n_lambda, n_fold))
115-
iter_arr = np.zeros(n_lambda, dtype=int)
116-
cv = np.zeros(n_lambda)
117-
cvstd = np.zeros(n_lambda)
115+
iter_arr = np.zeros(n_lambda)
118116

119117
residue = np.zeros(max_iter)
120118
data_consistency = np.zeros(max_iter)
@@ -124,12 +122,12 @@ def fista_cv(
124122
y_train = s[..., fold]
125123
x_test = matrix_test[..., fold]
126124
y_test = s_test[..., fold]
127-
y_points = np.prod(y_test.shape)
125+
y_points = y_test.shape[0] * y_test.shape[1]
128126

129127
gradient = x_train.T @ x_train
130128
c = x_train.T @ y_train
131129

132-
norm_factor = np.linalg.norm(y_train) ** 2
130+
norm_factor = norm(y_train) ** 2
133131
f_k = np.zeros((n_features, n_targets))
134132
y_k = f_k.copy()
135133

@@ -157,7 +155,7 @@ def fista_cv(
157155
else:
158156
f_k[:] = l1_soft_threshold(temp_c, l_inv * lam)
159157

160-
residue[k] = np.linalg.norm(x_train @ f_k - y_train) ** 2
158+
residue[k] = norm(x_train @ f_k - y_train) ** 2
161159
fk_l1 = np.sum(np.abs(f_k))
162160
data_consistency[k] = residue[k] + lam * fk_l1
163161

@@ -182,7 +180,24 @@ def fista_cv(
182180
prediction_error[j, fold] = err / y_points
183181
iter_arr[j] = k
184182

183+
return prediction_error, iter_arr
184+
185+
186+
def fista_cv(
187+
matrix: np.ndarray,
188+
s: np.ndarray,
189+
matrix_test: np.ndarray,
190+
s_test: np.ndarray,
191+
max_iter: int,
192+
lambda_vals: np.ndarray,
193+
nonnegative: bool,
194+
l_inv: float,
195+
tol: float,
196+
):
197+
prediction_error, iter_arr = fista_cv_nb(
198+
matrix, s, matrix_test, s_test, max_iter, lambda_vals, nonnegative, l_inv, tol
199+
)
200+
185201
cv = prediction_error.mean(axis=1)
186202
cvstd = prediction_error.std(axis=1)
187-
188203
return cv, cvstd, prediction_error, iter_arr

0 commit comments

Comments
 (0)