Skip to content

Commit 61a5fbf

Browse files
authored
Update ChebyshevIteration.java
1 parent 4796867 commit 61a5fbf

File tree

1 file changed

+41
-85
lines changed

1 file changed

+41
-85
lines changed

src/main/java/com/thealgorithms/maths/ChebyshevIteration.java

Lines changed: 41 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,39 @@
66
* Implements the Chebyshev Iteration method for solving a system of linear
77
* equations (Ax = b).
88
*
9-
* @author Mitrajit ghorui (github: keyKyrios)
10-
* @see <a href="https://en.wikipedia.org/wiki/Chebyshev_iteration">Wikipedia
11-
* Page</a>
9+
* This iterative method requires:
10+
* - Matrix A to be symmetric positive definite (SPD)
11+
* - Knowledge of minimum (lambdaMin) and maximum (lambdaMax) eigenvalues
1212
*
13-
* This is an iterative method that requires the matrix A to be
14-
* symmetric positive definite (SPD).
15-
* It also requires knowledge of the minimum (lambdaMin) and maximum
16-
* (lambdaMax)
17-
* eigenvalues of the matrix A.
13+
* Reference: https://en.wikipedia.org/wiki/Chebyshev_iteration
1814
*
19-
* The algorithm converges faster than simpler methods like Jacobi or
20-
* Gauss-Seidel
21-
* by using Chebyshev polynomials to optimize the update steps.
15+
* Author: Mitrajit Ghorui (github: keyKyrios)
2216
*/
2317
public final class ChebyshevIteration {
2418

25-
/**
26-
* Private constructor to prevent instantiation of this utility class.
27-
*/
2819
private ChebyshevIteration() {
2920
}
3021

3122
/**
32-
* Solves the linear system Ax = b using Chebyshev Iteration.
23+
* Solves Ax = b using Chebyshev Iteration.
3324
*
34-
* @param a A symmetric positive definite matrix.
35-
* @param b The vector 'b' in the equation Ax = b.
36-
* @param x0 An initial guess vector for 'x'.
37-
* @param lambdaMin The smallest eigenvalue of matrix A.
38-
* @param lambdaMax The largest eigenvalue of matrix A.
39-
* @param maxIterations The maximum number of iterations to perform.
40-
* @param tolerance The desired tolerance for convergence (e.g., 1e-10).
41-
* @return The solution vector 'x'.
42-
* @throws IllegalArgumentException if matrix/vector dimensions don't
43-
* match,
44-
* or if max/min eigenvalues are invalid.
25+
* @param A SPD matrix
26+
* @param b vector b
27+
* @param x0 initial guess
28+
* @param lambdaMin minimum eigenvalue
29+
* @param lambdaMax maximum eigenvalue
30+
* @param maxIterations maximum iterations
31+
* @param tolerance convergence tolerance
32+
* @return solution vector x
4533
*/
46-
public static double[] solve(double[][] a, double[] b, double[] x0, double lambdaMin, double lambdaMax, int maxIterations, double tolerance) {
47-
validateInputs(a, b, x0, lambdaMin, lambdaMax);
34+
public static double[] solve(double[][] A, double[] b, double[] x0,
35+
double lambdaMin, double lambdaMax,
36+
int maxIterations, double tolerance) {
37+
validateInputs(A, b, x0, lambdaMin, lambdaMax);
4838

4939
int n = b.length;
5040
double[] x = Arrays.copyOf(x0, n);
51-
double[] r = vectorSubtract(b, matrixVectorMultiply(a, x)); // Use `a`
41+
double[] r = vectorSubtract(b, matrixVectorMultiply(A, x));
5242
double[] p = new double[n];
5343
double alpha = 0.0;
5444
double beta = 0.0;
@@ -66,105 +56,71 @@ public static double[] solve(double[][] a, double[] b, double[] x0, double lambd
6656
p = Arrays.copyOf(r, n);
6757
} else {
6858
double alphaPrev = alpha;
69-
70-
beta = (c * alphaPrev / 2.0) * (c * alphaPrev / 2.0);
59+
beta = Math.pow(c * alphaPrev / 2.0, 2);
7160
alpha = 1.0 / (d - beta / alphaPrev);
72-
73-
double betaOverAlphaPrev = beta / alphaPrev;
74-
double[] rScaled = vectorScale(p, betaOverAlphaPrev);
75-
p = vectorAdd(r, rScaled);
61+
p = vectorAdd(r, vectorScale(p, beta / alphaPrev));
7662
}
7763

78-
double[] pScaled = vectorScale(p, alpha);
79-
x = vectorAdd(x, pScaled);
80-
81-
// Re-calculate residual to avoid accumulating floating-point errors
82-
r = vectorSubtract(b, matrixVectorMultiply(a, x)); // Use `a`
64+
x = vectorAdd(x, vectorScale(p, alpha));
65+
r = vectorSubtract(b, matrixVectorMultiply(A, x));
8366

8467
if (vectorNorm(r) < tolerance) {
85-
break; // Converged
68+
break;
8669
}
8770
}
71+
8872
return x;
8973
}
9074

91-
// --- Helper Methods for Linear Algebra ---
92-
private static void validateInputs(double[][] a, double[] b, double[] x0, double lambdaMin, double lambdaMax) {
75+
private static void validateInputs(double[][] A, double[] b, double[] x0,
76+
double lambdaMin, double lambdaMax) {
9377
int n = b.length;
94-
if (n == 0) {
95-
throw new IllegalArgumentException("Vectors cannot be empty.");
96-
}
97-
if (a.length != n || a[0].length != n) { // Use `a`
98-
throw new IllegalArgumentException("Matrix A must be square with dimensions n x n.");
99-
}
100-
if (x0.length != n) {
78+
if (n == 0) throw new IllegalArgumentException("Vectors cannot be empty.");
79+
if (A.length != n || A[0].length != n)
80+
throw new IllegalArgumentException("Matrix A must be square (n x n).");
81+
if (x0.length != n)
10182
throw new IllegalArgumentException("Initial guess vector x0 must have length n.");
102-
}
103-
if (lambdaMin >= lambdaMax || lambdaMin <= 0) {
83+
if (lambdaMin >= lambdaMax || lambdaMin <= 0)
10484
throw new IllegalArgumentException("Eigenvalues must satisfy 0 < lambdaMin < lambdaMax.");
105-
}
10685
}
10786

108-
/**
109-
* Computes y = Ax
110-
*/
111-
private static double[] matrixVectorMultiply(double[][] a, double[] x) {
112-
int n = a.length; // Use `a`
87+
private static double[] matrixVectorMultiply(double[][] A, double[] x) {
88+
int n = A.length;
11389
double[] y = new double[n];
11490
for (int i = 0; i < n; i++) {
11591
double sum = 0.0;
11692
for (int j = 0; j < n; j++) {
117-
sum += a[i][j] * x[j]; // Use `a`
93+
sum += A[i][j] * x[j];
11894
}
11995
y[i] = sum;
12096
}
12197
return y;
12298
}
12399

124-
/**
125-
* Computes c = a + b
126-
*/
127100
private static double[] vectorAdd(double[] a, double[] b) {
128101
int n = a.length;
129102
double[] c = new double[n];
130-
for (int i = 0; i < n; i++) {
131-
c[i] = a[i] + b[i];
132-
}
103+
for (int i = 0; i < n; i++) c[i] = a[i] + b[i];
133104
return c;
134105
}
135106

136-
/**
137-
* Computes c = a - b
138-
*/
139107
private static double[] vectorSubtract(double[] a, double[] b) {
140108
int n = a.length;
141109
double[] c = new double[n];
142-
for (int i = 0; i < n; i++) {
143-
c[i] = a[i] - b[i];
144-
}
110+
for (int i = 0; i < n; i++) c[i] = a[i] - b[i];
145111
return c;
146112
}
147113

148-
/**
149-
* Computes c = a * scalar
150-
*/
151114
private static double[] vectorScale(double[] a, double scalar) {
152115
int n = a.length;
153116
double[] c = new double[n];
154-
for (int i = 0; i < n; i++) {
155-
c[i] = a[i] * scalar;
156-
}
117+
for (int i = 0; i < n; i++) c[i] = a[i] * scalar;
157118
return c;
158119
}
159120

160-
/**
161-
* Computes the L2 norm (Euclidean norm) of a vector
162-
*/
163121
private static double vectorNorm(double[] a) {
164-
double sumOfSquares = 0.0;
165-
for (double val : a) {
166-
sumOfSquares += val * val;
167-
}
168-
return Math.sqrt(sumOfSquares);
122+
double sum = 0.0;
123+
for (double val : a) sum += val * val;
124+
return Math.sqrt(sum);
169125
}
170126
}

0 commit comments

Comments
 (0)