Skip to content

Commit b315882

Browse files
authored
Update ChebyshevIteration.java
1 parent e00efa5 commit b315882

File tree

1 file changed

+138
-78
lines changed

1 file changed

+138
-78
lines changed
Lines changed: 138 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,137 +1,197 @@
11
package com.thealgorithms.maths;
22

3-
import java.util.Arrays;
4-
53
/**
6-
* Implements the Chebyshev Iteration method for solving a system of linear
7-
* equations (Ax = b).
4+
* In numerical analysis, Chebyshev iteration is an iterative method for solving
5+
* systems of linear equations Ax = b. It is designed for systems where the
6+
* matrix A is symmetric positive-definite (SPD).
7+
*
8+
* <p>
9+
* This method is a "polynomial acceleration" method, meaning it finds the
10+
* optimal polynomial to apply to the residual to accelerate convergence.
811
*
9-
* This iterative method requires:
10-
* - Matrix a to be symmetric positive definite (SPD)
11-
* - Knowledge of minimum (lambdaMin) and maximum (lambdaMax) eigenvalues
12+
* <p>
13+
* It requires knowledge of the bounds of the eigenvalues of the matrix A:
14+
* m(A) (smallest eigenvalue) and M(A) (largest eigenvalue).
1215
*
13-
* Reference: https://en.wikipedia.org/wiki/Chebyshev_iteration
16+
* <p>
17+
* Wikipedia: https://en.wikipedia.org/wiki/Chebyshev_iteration
1418
*
15-
* Author: Mitrajit Ghorui (github: keyKyrios)
19+
* @author Mitrajit Ghorui(KeyKyrios)
1620
*/
17-
public final class ChebyshevIteration {
21+
public final class Chebyshev {
1822

19-
private ChebyshevIteration() {
23+
private Chebyshev() {
2024
}
2125

2226
/**
23-
* Solves ax = b using Chebyshev Iteration.
27+
* Solves the linear system Ax = b using the Chebyshev iteration method.
2428
*
25-
* @param a SPD matrix
26-
* @param b vector b
27-
* @param x0 initial guess
28-
* @param lambdaMin minimum eigenvalue
29-
* @param lambdaMax maximum eigenvalue
30-
* @param maxIterations maximum iterations
31-
* @param tolerance convergence tolerance
32-
* @return solution vector x
29+
* <p>
30+
* NOTE: The matrix A *must* be symmetric positive-definite (SPD) for this
31+
* algorithm to converge.
32+
*
33+
* @param a The matrix A (must be square, SPD).
34+
* @param b The vector b.
35+
* @param x0 The initial guess vector.
36+
* @param minEigenvalue The smallest eigenvalue of A (m(A)).
37+
* @param maxEigenvalue The largest eigenvalue of A (M(A)).
38+
* @param maxIterations The maximum number of iterations to perform.
39+
* @param tolerance The desired tolerance for the residual norm.
40+
* @return The solution vector x.
41+
* @throws IllegalArgumentException if matrix/vector dimensions are
42+
* incompatible,
43+
* if maxIterations <= 0, or if eigenvalues are invalid (e.g., minEigenvalue
44+
* <= 0, maxEigenvalue <= minEigenvalue).
3345
*/
34-
public static double[] solve(double[][] a, double[] b, double[] x0, double lambdaMin, double lambdaMax,
35-
int maxIterations, double tolerance) {
36-
validateInputs(a, b, x0, lambdaMin, lambdaMax);
46+
public static double[] solve(
47+
double[][] a,
48+
double[] b,
49+
double[] x0,
50+
double minEigenvalue,
51+
double maxEigenvalue,
52+
int maxIterations,
53+
double tolerance
54+
) {
55+
validateInputs(a, b, x0, minEigenvalue, maxEigenvalue, maxIterations, tolerance);
3756

3857
int n = b.length;
39-
double[] x = Arrays.copyOf(x0, n);
58+
double[] x = x0.clone();
4059
double[] r = vectorSubtract(b, matrixVectorMultiply(a, x));
4160
double[] p = new double[n];
61+
62+
double d = (maxEigenvalue + minEigenvalue) / 2.0;
63+
double c = (maxEigenvalue - minEigenvalue) / 2.0;
64+
4265
double alpha = 0.0;
43-
double beta = 0.0;
44-
double c = (lambdaMax - lambdaMin) / 2.0;
45-
double d = (lambdaMax + lambdaMin) / 2.0;
66+
double alphaPrev = 0.0;
4667

47-
double initialResidualNorm = vectorNorm(r);
48-
if (initialResidualNorm < tolerance) {
49-
return x; // Already converged
50-
}
68+
for (int k = 0; k < maxIterations; k++) {
69+
double residualNorm = vectorNorm(r);
70+
if (residualNorm < tolerance) {
71+
return x; // Solution converged
72+
}
5173

52-
for (int k = 1; k <= maxIterations; k++) {
53-
if (k == 1) {
74+
if (k == 0) {
5475
alpha = 1.0 / d;
55-
p = Arrays.copyOf(r, n);
76+
System.arraycopy(r, 0, p, 0, n); // p = r
5677
} else {
57-
double alphaPrev = alpha;
58-
beta = Math.pow(c * alphaPrev / 2.0, 2);
78+
double beta = (c * alphaPrev / 2.0) * (c * alphaPrev / 2.0);
5979
alpha = 1.0 / (d - beta / alphaPrev);
60-
p = vectorAdd(r, vectorScale(p, beta / alphaPrev));
80+
double[] pUpdate = scalarMultiply(beta / alphaPrev, p);
81+
p = vectorAdd(r, pUpdate); // p = r + (beta / alphaPrev) * p
6182
}
6283

63-
x = vectorAdd(x, vectorScale(p, alpha));
64-
r = vectorSubtract(b, matrixVectorMultiply(a, x));
84+
double[] xUpdate = scalarMultiply(alpha, p);
85+
x = vectorAdd(x, xUpdate); // x = x + alpha * p
6586

66-
if (vectorNorm(r) < tolerance) {
67-
break;
68-
}
87+
// Recompute residual for accuracy, though it can be updated iteratively
88+
r = vectorSubtract(b, matrixVectorMultiply(a, x));
89+
alphaPrev = alpha;
6990
}
7091

71-
return x;
92+
return x; // Return best guess after maxIterations
7293
}
7394

74-
private static void validateInputs(double[][] a, double[] b, double[] x0, double lambdaMin, double lambdaMax) {
75-
int n = b.length;
95+
/**
96+
* Validates the inputs for the Chebyshev solver.
97+
*/
98+
private static void validateInputs(
99+
double[][] a,
100+
double[] b,
101+
double[] x0,
102+
double minEigenvalue,
103+
double maxEigenvalue,
104+
int maxIterations,
105+
double tolerance
106+
) {
107+
int n = a.length;
76108
if (n == 0) {
77-
throw new IllegalArgumentException("Vectors cannot be empty.");
109+
throw new IllegalArgumentException("Matrix A cannot be empty.");
110+
}
111+
if (n != a[0].length) {
112+
throw new IllegalArgumentException("Matrix A must be square.");
113+
}
114+
if (n != b.length) {
115+
throw new IllegalArgumentException("Matrix A and vector b dimensions do not match.");
78116
}
79-
if (a.length != n || a[0].length != n) {
80-
throw new IllegalArgumentException("Matrix a must be square (n x n).");
117+
if (n != x0.length) {
118+
throw new IllegalArgumentException("Matrix A and vector x0 dimensions do not match.");
81119
}
82-
if (x0.length != n) {
83-
throw new IllegalArgumentException("Initial guess vector x0 must have length n.");
120+
if (minEigenvalue <= 0) {
121+
throw new IllegalArgumentException("Smallest eigenvalue must be positive (matrix must be positive-definite).");
84122
}
85-
if (lambdaMin >= lambdaMax || lambdaMin <= 0) {
86-
throw new IllegalArgumentException("Eigenvalues must satisfy 0 < lambdaMin < lambdaMax.");
123+
if (maxEigenvalue <= minEigenvalue) {
124+
throw new IllegalArgumentException("Max eigenvalue must be strictly greater than min eigenvalue.");
125+
}
126+
if (maxIterations <= 0) {
127+
throw new IllegalArgumentException("Max iterations must be positive.");
128+
}
129+
if (tolerance <= 0) {
130+
throw new IllegalArgumentException("Tolerance must be positive.");
87131
}
88132
}
89133

90-
private static double[] matrixVectorMultiply(double[][] a, double[] x) {
134+
// --- Vector/Matrix Helper Methods ---
135+
/**
136+
* Computes the product of a matrix A and a vector v (Av).
137+
*/
138+
private static double[] matrixVectorMultiply(double[][] a, double[] v) {
91139
int n = a.length;
92-
double[] y = new double[n];
140+
double[] result = new double[n];
93141
for (int i = 0; i < n; i++) {
94-
double sum = 0.0;
142+
double sum = 0;
95143
for (int j = 0; j < n; j++) {
96-
sum += a[i][j] * x[j];
144+
sum += a[i][j] * v[j];
97145
}
98-
y[i] = sum;
146+
result[i] = sum;
99147
}
100-
return y;
148+
return result;
101149
}
102150

103-
private static double[] vectorAdd(double[] a, double[] b) {
104-
int n = a.length;
105-
double[] c = new double[n];
151+
/**
152+
* Computes the subtraction of two vectors (v1 - v2).
153+
*/
154+
private static double[] vectorSubtract(double[] v1, double[] v2) {
155+
int n = v1.length;
156+
double[] result = new double[n];
106157
for (int i = 0; i < n; i++) {
107-
c[i] = a[i] + b[i];
158+
result[i] = v1[i] - v2[i];
108159
}
109-
return c;
160+
return result;
110161
}
111162

112-
private static double[] vectorSubtract(double[] a, double[] b) {
113-
int n = a.length;
114-
double[] c = new double[n];
163+
/**
164+
* Computes the addition of two vectors (v1 + v2).
165+
*/
166+
private static double[] vectorAdd(double[] v1, double[] v2) {
167+
int n = v1.length;
168+
double[] result = new double[n];
115169
for (int i = 0; i < n; i++) {
116-
c[i] = a[i] - b[i];
170+
result[i] = v1[i] + v2[i];
117171
}
118-
return c;
172+
return result;
119173
}
120174

121-
private static double[] vectorScale(double[] a, double scalar) {
122-
int n = a.length;
123-
double[] c = new double[n];
175+
/**
176+
* Computes the product of a scalar and a vector (s * v).
177+
*/
178+
private static double[] scalarMultiply(double scalar, double[] v) {
179+
int n = v.length;
180+
double[] result = new double[n];
124181
for (int i = 0; i < n; i++) {
125-
c[i] = a[i] * scalar;
182+
result[i] = scalar * v[i];
126183
}
127-
return c;
184+
return result;
128185
}
129186

130-
private static double vectorNorm(double[] a) {
131-
double sum = 0.0;
132-
for (double val : a) {
133-
sum += val * val;
187+
/**
188+
* Computes the L2 norm (Euclidean norm) of a vector.
189+
*/
190+
private static double vectorNorm(double[] v) {
191+
double sumOfSquares = 0;
192+
for (double val : v) {
193+
sumOfSquares += val * val;
134194
}
135-
return Math.sqrt(sum);
195+
return Math.sqrt(sumOfSquares);
136196
}
137197
}

0 commit comments

Comments
 (0)