|
| 1 | +import numba as nb |
| 2 | +import numpy as np |
| 3 | +from numpy.linalg import norm |
| 4 | + |
| 5 | + |
@nb.njit(fastmath=True)
def l1_soft_threshold(a: np.ndarray, c: float):
    """Elementwise soft threshold: sign(a) * max(|a| - c, 0), no sign constraint."""
    shrunk = np.abs(a) - c
    # Magnitudes that fall below the threshold are zeroed; the rest keep a's sign.
    return np.where(shrunk < 0, 0.0, shrunk * np.sign(a))
| 12 | + |
| 13 | + |
@nb.njit(fastmath=True)
def l1_soft_threshold_nonnegative(a: np.ndarray, c: float):
    """Elementwise non-negative soft threshold: max(a - c, 0)."""
    shifted = a - c
    # (x + |x|) / 2 is an allocation-friendly max(x, 0) that numba vectorizes well.
    return (shifted + np.abs(shifted)) / 2.0
| 21 | + |
| 22 | + |
@nb.njit(fastmath=True)
def fista(
    matrix: np.ndarray,
    s: np.ndarray,
    f_k: np.ndarray,
    lam: float,
    max_iter: int,
    tol: float,
    l_inv: float,
    nonnegative: bool = False,
):
    """Solve min_f ||matrix @ f - s||^2 + lam * ||f||_1 with FISTA.

    Parameters
    ----------
    matrix : system matrix A.
    s : measured data.
    f_k : initial guess; updated IN PLACE and also returned.
    lam : L1 regularization weight.
    max_iter : maximum number of iterations.
    tol : relative tolerance on the trailing 5-iteration residual average
        used for early stopping.
    l_inv : step size, the reciprocal of the Lipschitz constant of the
        gradient of the data term — TODO confirm against caller.
    nonnegative : if True, constrain the solution to be non-negative.

    Returns
    -------
    zf : matrix @ f_k (the reconstruction).
    f_k : the solution (the same array object that was passed in).
    residue : per-iteration squared residual, normalized by ||s||^2 (trimmed).
    data_consistency : per-iteration objective value (trimmed).
    n_iter : index of the last iteration performed.
    """
    matrix = np.asarray(matrix)
    s = np.asarray(s)

    residue = np.zeros(max_iter)
    data_consistency = np.zeros(max_iter)

    # Precompute the constant pieces of the gradient: A^T A and A^T s.
    gradient = matrix.T @ matrix
    c = matrix.T @ s

    normalization_factor = norm(s) ** 2

    # Seed the "last accepted" state from the actual initial guess, so the
    # data-consistency rollback below is valid even at k == 0.  (For the
    # common zero initial guess this reduces to the previous hard-coded
    # values: norm(s)**2 and 0.)  The original code indexed [k - 1] at
    # k == 0, which wrapped to the zero-filled end of the arrays.
    last_fk_l1 = np.sum(np.abs(f_k))
    init_res = norm(matrix @ f_k - s) ** 2
    last_data_consistency = init_res + lam * last_fk_l1
    last_residue = init_res / normalization_factor

    t_kp1 = 1.0
    y_k = f_k.copy()
    n_iter = 0  # stays 0 if max_iter == 0 (loop body never runs)

    for k in range(max_iter):
        n_iter = k
        t_k = t_kp1
        f_km1 = f_k.copy()

        # Nesterov momentum schedule.
        t_kp1 = (1.0 + np.sqrt(1.0 + 4.0 * t_k**2)) * 0.5
        constant_factor = (t_k - 1.0) / t_kp1

        # Gradient step on the extrapolated point y_k.
        temp_c = gradient @ y_k
        temp_c = l_inv * (c - temp_c)
        temp_c += y_k

        # Proximal (soft-threshold) step, optionally non-negative.
        if nonnegative:
            f_k[:] = l1_soft_threshold_nonnegative(temp_c, lam * l_inv)
        else:
            f_k[:] = l1_soft_threshold(temp_c, lam * l_inv)

        residue_temp = norm((matrix @ f_k) - s) ** 2
        fk_l1 = np.sum(np.abs(f_k))

        data_consistency[k] = residue_temp + lam * fk_l1
        residue[k] = residue_temp / normalization_factor

        # Early stop once the residual has plateaued over the last 5 steps.
        # The residue[k] > 0 guard avoids a divide-by-zero on exact fits.
        if k >= 5 and residue[k] > 0.0:
            recent_avg = np.mean(residue[k - 5 : k])
            if abs(1.0 - (recent_avg / residue[k])) <= tol:
                break

        # Data-consistency check: if the objective increased, roll back to
        # the previously accepted iterate.  Tracked via scalars instead of
        # [k - 1] indexing, which was invalid at k == 0.
        if data_consistency[k] > last_data_consistency:
            f_k[:] = f_km1
            fk_l1 = last_fk_l1
            data_consistency[k] = last_data_consistency
            residue[k] = last_residue

        last_data_consistency = data_consistency[k]
        last_fk_l1 = fk_l1
        last_residue = residue[k]

        # Momentum extrapolation for the next gradient step.
        y_k = f_k + constant_factor * (f_k - f_km1)

    zf = matrix @ f_k

    return zf, f_k, residue[: n_iter + 1], data_consistency[: n_iter + 1], n_iter
| 96 | + |
| 97 | + |
@nb.njit(fastmath=True)
def fista_cv_nb(
    matrix: np.ndarray,
    s: np.ndarray,
    matrix_test: np.ndarray,
    s_test: np.ndarray,
    max_iter: int,
    lambda_vals: np.ndarray,
    nonnegative: bool,
    l_inv: float,
    tol: float,
):
    """Run FISTA over a lambda grid for each CV fold; return test errors.

    Parameters
    ----------
    matrix, s : training design matrices / targets, folds stacked on the
        LAST axis.
    matrix_test, s_test : matching held-out design matrices / targets.
    max_iter : per-(fold, lambda) iteration cap.
    lambda_vals : grid of L1 weights to evaluate.
    nonnegative : if True, constrain solutions to be non-negative.
    l_inv : step size (reciprocal Lipschitz constant) — TODO confirm.
    tol : relative plateau tolerance for early stopping.

    Returns
    -------
    prediction_error : (n_lambda, n_fold) mean squared test error.
    iter_arr : iterations used per lambda.
        NOTE(review): this is overwritten on every fold, so it reflects
        only the LAST fold — confirm that is the intent.
    """
    n_fold = matrix.shape[-1]
    n_lambda = len(lambda_vals)
    n_targets = s.shape[1]
    n_features = matrix.shape[1]
    prediction_error = np.zeros((n_lambda, n_fold))
    iter_arr = np.zeros(n_lambda)

    residue = np.zeros(max_iter)
    data_consistency = np.zeros(max_iter)

    for fold in range(n_fold):
        x_train = matrix[..., fold]
        y_train = s[..., fold]
        x_test = matrix_test[..., fold]
        y_test = s_test[..., fold]
        y_points = y_test.shape[0] * y_test.shape[1]

        # Fold-constant gradient pieces: A^T A and A^T y.
        gradient = x_train.T @ x_train
        c = x_train.T @ y_train

        norm_factor = norm(y_train) ** 2
        f_k = np.zeros((n_features, n_targets))
        y_k = f_k.copy()

        for j, lam in enumerate(lambda_vals):
            # Warm reset: every lambda restarts from the zero solution.
            t_kp1 = 1.0
            y_k[:] = 0.0
            f_k[:] = 0.0
            residue[:] = 0
            data_consistency[:] = 0
            fk_l1 = 0.0
            # "Last accepted" state for the zero start: residual ||y||^2,
            # zero L1 norm.  Tracked as scalars so the rollback below is
            # valid at k == 0 (the original indexed [k - 1], which wrapped
            # to the end of the array on the first iteration).
            last_fk_l1 = 0.0
            last_data_consistency = norm_factor
            last_residue = norm_factor
            k = 0  # defined even if max_iter == 0

            for k in range(max_iter):
                t_k = t_kp1
                f_km1 = f_k.copy()
                # Nesterov momentum schedule.
                t_kp1 = (1.0 + np.sqrt(1.0 + 4.0 * t_k**2)) / 2.0
                constant_factor = (t_k - 1.0) / t_kp1

                # Gradient step on the extrapolated point.
                grad_yk = gradient @ y_k
                temp_c = y_k - l_inv * (grad_yk - c)

                # Proximal (soft-threshold) step.
                if nonnegative:
                    f_k[:] = l1_soft_threshold_nonnegative(temp_c, l_inv * lam)
                else:
                    f_k[:] = l1_soft_threshold(temp_c, l_inv * lam)

                residue[k] = norm(x_train @ f_k - y_train) ** 2
                fk_l1 = np.sum(np.abs(f_k))
                data_consistency[k] = residue[k] + lam * fk_l1

                # Early stop on a 5-iteration residual plateau; the > 0
                # guard avoids a divide-by-zero on exact fits.
                if k >= 5 and residue[k] > 0.0:
                    recent_avg = np.mean(residue[k - 5 : k])
                    if abs(1.0 - (recent_avg / residue[k])) <= tol:
                        break

                # Data-consistency check: roll back if the objective rose.
                if data_consistency[k] > last_data_consistency:
                    f_k[:] = f_km1
                    fk_l1 = last_fk_l1
                    data_consistency[k] = last_data_consistency
                    residue[k] = last_residue

                last_data_consistency = data_consistency[k]
                last_fk_l1 = fk_l1
                last_residue = residue[k]

                # Momentum extrapolation for the next step.
                y_k = f_k + constant_factor * (f_k - f_km1)

            # Held-out mean squared error for this (lambda, fold).
            err = np.linalg.norm(x_test @ f_k - y_test) ** 2
            prediction_error[j, fold] = err / y_points
            iter_arr[j] = k

    return prediction_error, iter_arr
| 184 | + |
| 185 | + |
def fista_cv(
    matrix: np.ndarray,
    s: np.ndarray,
    matrix_test: np.ndarray,
    s_test: np.ndarray,
    max_iter: int,
    lambda_vals: np.ndarray,
    nonnegative: bool,
    l_inv: float,
    tol: float,
):
    """Cross-validate FISTA over a lambda grid.

    Thin Python wrapper around the jitted `fista_cv_nb`; returns the
    per-lambda CV mean, CV std, the raw per-fold prediction errors, and
    the iteration counts.
    """
    errors, iterations = fista_cv_nb(
        matrix,
        s,
        matrix_test,
        s_test,
        max_iter,
        lambda_vals,
        nonnegative,
        l_inv,
        tol,
    )

    # Aggregate across folds (axis 1) for the usual CV curve and error bars.
    mean_err = np.mean(errors, axis=1)
    std_err = np.std(errors, axis=1)
    return mean_err, std_err, errors, iterations
0 commit comments