package com.thealgorithms.maths;

/**
 * In numerical analysis, Chebyshev iteration is an iterative method for solving
 * systems of linear equations Ax = b. It is designed for systems where the
 * matrix A is symmetric positive-definite (SPD).
 *
 * <p>
 * It is a "polynomial acceleration" method: the iterates are chosen so that the error is
 * damped by scaled Chebyshev polynomials, which minimize the worst-case error over the
 * eigenvalue interval. Unlike the conjugate gradient method, it requires no inner products.
 *
 * <p>
 * It requires knowledge of the bounds of the eigenvalues of the matrix A:
 * m(A) (smallest eigenvalue) and M(A) (largest eigenvalue).
 *
 * <p>
 * Wikipedia: https://en.wikipedia.org/wiki/Chebyshev_iteration
 *
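 * <p>
 * Example usage (a minimal sketch; the 2x2 SPD matrix and the eigenvalue bounds below are
 * illustrative values chosen for this snippet, not prescribed by the algorithm):
 * <pre>{@code
 * double[][] a = {{4.0, 1.0}, {1.0, 3.0}}; // symmetric positive-definite
 * double[] b = {1.0, 2.0};
 * double[] x0 = {0.0, 0.0}; // initial guess
 * // The eigenvalues of a are (7 +/- sqrt(5)) / 2, roughly 2.38 and 4.62,
 * // so 2.3 and 4.7 are safe lower/upper bounds.
 * double[] x = ChebyshevIteration.solve(a, b, x0, 2.3, 4.7, 100, 1e-10);
 * }</pre>
 *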
 * @author Mitrajit Ghorui (KeyKyrios)
 */
public final class ChebyshevIteration {

    private ChebyshevIteration() {
    }

    /**
     * Solves the linear system Ax = b using the Chebyshev iteration method.
     *
     * <p>
     * NOTE: The matrix A *must* be symmetric positive-definite (SPD) for this
     * algorithm to converge.
     *
     * @param a The matrix A (must be square, SPD).
     * @param b The vector b.
     * @param x0 The initial guess vector.
     * @param minEigenvalue The smallest eigenvalue of A, m(A), or a lower bound on it.
     * @param maxEigenvalue The largest eigenvalue of A, M(A), or an upper bound on it.
     * @param maxIterations The maximum number of iterations to perform.
     * @param tolerance The desired tolerance for the residual norm.
     * @return The approximate solution vector x.
     * @throws IllegalArgumentException if the matrix/vector dimensions are incompatible,
     *                                  if maxIterations or tolerance is not positive, or if the
     *                                  eigenvalue bounds are invalid (minEigenvalue <= 0 or
     *                                  maxEigenvalue <= minEigenvalue).
     */
    public static double[] solve(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) {
        validateInputs(a, b, x0, minEigenvalue, maxEigenvalue, maxIterations, tolerance);

        int n = b.length;
        double[] x = x0.clone();
        double[] r = vectorSubtract(b, matrixVectorMultiply(a, x));
        double[] p = new double[n];

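        // d is the center and c the half-width of the eigenvalue interval
        // [minEigenvalue, maxEigenvalue]; together they define the shifted Chebyshev
        // polynomials that determine the step sizes (alpha) and momentum terms (beta).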
        double d = (maxEigenvalue + minEigenvalue) / 2.0;
        double c = (maxEigenvalue - minEigenvalue) / 2.0;

        double alpha = 0.0;
        double alphaPrev = 0.0;

        for (int k = 0; k < maxIterations; k++) {
            double residualNorm = vectorNorm(r);
            if (residualNorm < tolerance) {
                return x; // Solution converged
            }

            if (k == 0) {
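                // First iteration: a plain Richardson step in the direction of the
                // residual, with step size 1/d (the optimal fixed step for
                // eigenvalues in [minEigenvalue, maxEigenvalue]).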
                alpha = 1.0 / d;
                System.arraycopy(r, 0, p, 0, n); // p = r
            } else {
                // Chebyshev recurrence: beta = (c * alphaPrev)^2 / 2 on the second
                // iteration (k == 1), beta = (c * alphaPrev / 2)^2 afterwards.
                double beta = (k == 1) ? 0.5 * (c * alphaPrev) * (c * alphaPrev) : (c * alphaPrev / 2.0) * (c * alphaPrev / 2.0);
                alpha = 1.0 / (d - beta / alphaPrev);
                double[] pUpdate = scalarMultiply(beta, p);
                p = vectorAdd(r, pUpdate); // p = r + beta * p
            }

            double[] xUpdate = scalarMultiply(alpha, p);
            x = vectorAdd(x, xUpdate); // x = x + alpha * p

            // Recompute the residual directly as r = b - A * x; the incremental update
            // r = r - alpha * A * p is cheaper but accumulates rounding error.
            r = vectorSubtract(b, matrixVectorMultiply(a, x));
            alphaPrev = alpha;
        }

        return x; // Return best guess after maxIterations
    }

    /**
     * Validates the inputs for the Chebyshev solver.
     */
    private static void validateInputs(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) {
        int n = a.length;
        if (n == 0) {
            throw new IllegalArgumentException("Matrix A cannot be empty.");
        }
        if (n != a[0].length) {
            throw new IllegalArgumentException("Matrix A must be square.");
        }
        if (n != b.length) {
            throw new IllegalArgumentException("Matrix A and vector b dimensions do not match.");
        }
        if (n != x0.length) {
            throw new IllegalArgumentException("Matrix A and vector x0 dimensions do not match.");
        }
        if (minEigenvalue <= 0) {
            throw new IllegalArgumentException("Smallest eigenvalue must be positive (matrix must be positive-definite).");
        }
        if (maxEigenvalue <= minEigenvalue) {
            throw new IllegalArgumentException("Max eigenvalue must be strictly greater than min eigenvalue.");
        }
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("Max iterations must be positive.");
        }
        if (tolerance <= 0) {
            throw new IllegalArgumentException("Tolerance must be positive.");
        }
    }

    // --- Vector/Matrix Helper Methods ---
    /**
     * Computes the product of a matrix A and a vector v (Av).
     */
    private static double[] matrixVectorMultiply(double[][] a, double[] v) {
        int n = a.length;
        double[] result = new double[n];
        for (int i = 0; i < n; i++) {
            double sum = 0;
            for (int j = 0; j < n; j++) {
                sum += a[i][j] * v[j];
            }
            result[i] = sum;
        }
        return result;
    }

    /**
     * Computes the subtraction of two vectors (v1 - v2).
     */
    private static double[] vectorSubtract(double[] v1, double[] v2) {
        int n = v1.length;
        double[] result = new double[n];
        for (int i = 0; i < n; i++) {
            result[i] = v1[i] - v2[i];
        }
        return result;
    }

    /**
     * Computes the addition of two vectors (v1 + v2).
     */
    private static double[] vectorAdd(double[] v1, double[] v2) {
        int n = v1.length;
        double[] result = new double[n];
        for (int i = 0; i < n; i++) {
            result[i] = v1[i] + v2[i];
        }
        return result;
    }

    /**
     * Computes the product of a scalar and a vector (s * v).
     */
    private static double[] scalarMultiply(double scalar, double[] v) {
        int n = v.length;
        double[] result = new double[n];
        for (int i = 0; i < n; i++) {
            result[i] = scalar * v[i];
        }
        return result;
    }

    /**
     * Computes the L2 norm (Euclidean norm) of a vector.
     */
    private static double vectorNorm(double[] v) {
        double sumOfSquares = 0;
        for (double val : v) {
            sumOfSquares += val * val;
        }
        return Math.sqrt(sumOfSquares);
    }
}