|
1 | 1 | package com.thealgorithms.maths; |
2 | 2 |
|
3 | | -import java.util.Arrays; |
4 | | - |
/**
 * Chebyshev iteration for solving linear systems {@code Ax = b}.
 *
 * <p>
 * The method targets systems whose matrix {@code A} is symmetric
 * positive-definite (SPD). It is a polynomial acceleration method: after
 * {@code k} steps the residual equals a scaled Chebyshev polynomial of degree
 * {@code k} in {@code A} applied to the initial residual. Unlike conjugate
 * gradients it needs no inner products for its coefficients, but it does need
 * bounds on the spectrum of {@code A}: a lower bound on the smallest eigenvalue
 * and an upper bound on the largest eigenvalue.
 *
 * <p>
 * Wikipedia: https://en.wikipedia.org/wiki/Chebyshev_iteration
 *
 * @author Mitrajit Ghorui(KeyKyrios)
 */
public final class Chebyshev {

    private Chebyshev() {
    }

    /**
     * Solves the linear system {@code Ax = b} using the Chebyshev iteration method.
     *
     * <p>
     * NOTE: The matrix A <em>must</em> be symmetric positive-definite (SPD) and its
     * eigenvalues must lie inside {@code [minEigenvalue, maxEigenvalue]} for this
     * algorithm to converge. Neither property is verified here.
     *
     * <p>
     * The input arrays are never modified; the returned array is a fresh copy.
     *
     * @param a             The matrix A (must be square, SPD).
     * @param b             The vector b.
     * @param x0            The initial guess vector.
     * @param minEigenvalue The smallest eigenvalue of A (or a positive lower bound).
     * @param maxEigenvalue The largest eigenvalue of A (or an upper bound).
     * @param maxIterations The maximum number of iterations to perform.
     * @param tolerance     The desired tolerance for the Euclidean residual norm.
     * @return The approximate solution vector x: either the first iterate whose
     *         residual norm is below {@code tolerance} (checked before each
     *         step), or the iterate reached after {@code maxIterations} steps.
     * @throws IllegalArgumentException if any array argument is {@code null}, if
     *                                  the matrix is empty or not square, if the
     *                                  vector lengths do not match the matrix, if
     *                                  {@code minEigenvalue <= 0}, if
     *                                  {@code maxEigenvalue <= minEigenvalue}, if an
     *                                  eigenvalue bound is NaN or infinite, if
     *                                  {@code maxIterations <= 0}, or if
     *                                  {@code tolerance} is NaN or {@code <= 0}.
     */
    public static double[] solve(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) {
        validateInputs(a, b, x0, minEigenvalue, maxEigenvalue, maxIterations, tolerance);

        final int n = b.length;
        // Centre (d) and half-width (c) of the interval containing the spectrum.
        final double d = (maxEigenvalue + minEigenvalue) / 2.0;
        final double c = (maxEigenvalue - minEigenvalue) / 2.0;

        double[] x = x0.clone();
        double[] r = residual(a, b, x);
        double[] p = new double[n];
        double alpha = 0.0;

        for (int k = 0; k < maxIterations; k++) {
            if (vectorNorm(r) < tolerance) {
                return x; // Solution converged
            }

            if (k == 0) {
                // First step: plain Richardson step with the optimal fixed parameter.
                alpha = 1.0 / d;
                System.arraycopy(r, 0, p, 0, n);
            } else {
                // The first update after the initial step uses (c * alpha)^2 / 2;
                // every later update uses (c * alpha / 2)^2. Dropping this special
                // case breaks the Chebyshev three-term recurrence.
                final double scaled = c * alpha;
                final double beta = (k == 1) ? 0.5 * scaled * scaled : 0.25 * scaled * scaled;
                alpha = 1.0 / (d - beta / alpha);
                // p = r + beta * p (beta is NOT divided by the previous alpha here).
                for (int i = 0; i < n; i++) {
                    p[i] = r[i] + beta * p[i];
                }
            }

            // x = x + alpha * p
            for (int i = 0; i < n; i++) {
                x[i] += alpha * p[i];
            }

            // Recompute the residual from scratch rather than updating it
            // recursively, so rounding errors do not accumulate in r.
            r = residual(a, b, x);
        }

        return x; // Best guess after maxIterations
    }

    /**
     * Validates the inputs for the Chebyshev solver.
     *
     * @throws IllegalArgumentException if any precondition documented on
     *                                  {@code solve} is violated.
     */
    private static void validateInputs(double[][] a, double[] b, double[] x0, double minEigenvalue, double maxEigenvalue, int maxIterations, double tolerance) {
        if (a == null || b == null || x0 == null) {
            throw new IllegalArgumentException("Inputs a, b and x0 must not be null.");
        }
        final int n = a.length;
        if (n == 0) {
            throw new IllegalArgumentException("Matrix A cannot be empty.");
        }
        // Check every row, not only the first one, so ragged matrices are rejected.
        for (double[] row : a) {
            if (row == null || row.length != n) {
                throw new IllegalArgumentException("Matrix A must be square.");
            }
        }
        if (n != b.length) {
            throw new IllegalArgumentException("Matrix A and vector b dimensions do not match.");
        }
        if (n != x0.length) {
            throw new IllegalArgumentException("Matrix A and vector x0 dimensions do not match.");
        }
        if (Double.isNaN(minEigenvalue) || Double.isNaN(maxEigenvalue) || Double.isInfinite(minEigenvalue) || Double.isInfinite(maxEigenvalue)) {
            throw new IllegalArgumentException("Eigenvalue bounds must be finite numbers.");
        }
        if (minEigenvalue <= 0) {
            throw new IllegalArgumentException("Smallest eigenvalue must be positive (matrix must be positive-definite).");
        }
        if (maxEigenvalue <= minEigenvalue) {
            throw new IllegalArgumentException("Max eigenvalue must be strictly greater than min eigenvalue.");
        }
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("Max iterations must be positive.");
        }
        // Written as a negated '>' so that NaN is rejected as well.
        if (!(tolerance > 0)) {
            throw new IllegalArgumentException("Tolerance must be positive.");
        }
    }

    // --- Vector/Matrix Helper Methods ---

    /**
     * Computes the residual {@code b - Ax} into a newly allocated vector.
     */
    private static double[] residual(double[][] a, double[] b, double[] x) {
        final int n = b.length;
        double[] result = new double[n];
        for (int i = 0; i < n; i++) {
            double sum = 0.0;
            for (int j = 0; j < n; j++) {
                sum += a[i][j] * x[j];
            }
            result[i] = b[i] - sum;
        }
        return result;
    }

    /**
     * Computes the L2 norm (Euclidean norm) of a vector.
     */
    private static double vectorNorm(double[] v) {
        double sumOfSquares = 0.0;
        for (double val : v) {
            sumOfSquares += val * val;
        }
        return Math.sqrt(sumOfSquares);
    }
}
0 commit comments