Skip to content

Commit 155ae56

Browse files
committed
feat: Add Chebyshev Iteration algorithm
1 parent ae2e40a commit 155ae56

File tree

2 files changed

+366
-0
lines changed

2 files changed

+366
-0
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
package com.thealgorithms.maths;
2+
3+
import java.util.Arrays;
4+
5+
/**
6+
* Implements the Chebyshev Iteration method for solving a system of linear
7+
* equations (Ax = b).
8+
*
9+
* @author Mitrajit ghorui (github: keyKyrios)
10+
* @see <a href="https://en.wikipedia.org/wiki/Chebyshev_iteration">Wikipedia
11+
* Page</a>
12+
*
13+
* This is an iterative method that requires the matrix A to be
14+
* symmetric positive definite (SPD).
15+
* It also requires knowledge of the minimum (lambdaMin) and maximum
16+
* (lambdaMax)
17+
* eigenvalues of the matrix A.
18+
*
19+
* The algorithm converges faster than simpler methods like Jacobi or
20+
* Gauss-Seidel
21+
* by using Chebyshev polynomials to optimize the update steps.
22+
*/
23+
public final class ChebyshevIteration {
24+
25+
/**
26+
* Private constructor to prevent instantiation of this utility class.
27+
*/
28+
private ChebyshevIteration() {
29+
}
30+
31+
/**
32+
* Solves the linear system Ax = b using Chebyshev Iteration.
33+
*
34+
* @param A A symmetric positive definite matrix.
35+
* @param b The vector 'b' in the equation Ax = b.
36+
* @param x0 An initial guess vector for 'x'.
37+
* @param lambdaMin The smallest eigenvalue of matrix A.
38+
* @param lambdaMax The largest eigenvalue of matrix A.
39+
* @param maxIterations The maximum number of iterations to perform.
40+
* @param tolerance The desired tolerance for convergence (e.g., 1e-10).
41+
* @return The solution vector 'x'.
42+
* @throws IllegalArgumentException if matrix/vector dimensions don't
43+
* match,
44+
* or if max/min eigenvalues are invalid.
45+
*/
46+
public static double[] solve(
47+
double[][] A,
48+
double[] b,
49+
double[] x0,
50+
double lambdaMin,
51+
double lambdaMax,
52+
int maxIterations,
53+
double tolerance
54+
) {
55+
validateInputs(A, b, x0, lambdaMin, lambdaMax);
56+
57+
int n = b.length;
58+
double[] x = Arrays.copyOf(x0, n);
59+
double[] r = vectorSubtract(b, matrixVectorMultiply(A, x));
60+
double[] p = new double[n];
61+
double alpha = 0.0;
62+
double beta = 0.0;
63+
double c = (lambdaMax - lambdaMin) / 2.0;
64+
double d = (lambdaMax + lambdaMin) / 2.0;
65+
66+
double initialResidualNorm = vectorNorm(r);
67+
if (initialResidualNorm < tolerance) {
68+
return x; // Already converged
69+
}
70+
71+
for (int k = 1; k <= maxIterations; k++) {
72+
if (k == 1) {
73+
alpha = 1.0 / d;
74+
p = Arrays.copyOf(r, n);
75+
} else {
76+
77+
// 1. Save the previous alpha
78+
double alphaPrev = alpha;
79+
80+
// 2. Calculate new beta and alpha using alphaPrev
81+
beta = (c * alphaPrev / 2.0) * (c * alphaPrev / 2.0);
82+
alpha = 1.0 / (d - beta / alphaPrev);
83+
84+
// 3. Use alphaPrev in the p update
85+
double betaOverAlphaPrev = beta / alphaPrev;
86+
double[] rScaled = vectorScale(p, betaOverAlphaPrev);
87+
p = vectorAdd(r, rScaled);
88+
89+
}
90+
91+
double[] pScaled = vectorScale(p, alpha);
92+
x = vectorAdd(x, pScaled);
93+
94+
// Re-calculate residual to avoid accumulating floating-point errors
95+
// Note: Some variants calculate r = r - alpha * A * p for
96+
// efficiency,
97+
// but this direct calculation is more stable against drift.
98+
r = vectorSubtract(b, matrixVectorMultiply(A, x));
99+
100+
if (vectorNorm(r) < tolerance) {
101+
break; // Converged
102+
}
103+
}
104+
return x;
105+
}
106+
107+
// --- Helper Methods for Linear Algebra ---
108+
private static void validateInputs(
109+
double[][] A,
110+
double[] b,
111+
double[] x0,
112+
double lambdaMin,
113+
double lambdaMax
114+
) {
115+
int n = b.length;
116+
if (n == 0) {
117+
throw new IllegalArgumentException("Vectors cannot be empty.");
118+
}
119+
if (A.length != n || A[0].length != n) {
120+
throw new IllegalArgumentException(
121+
"Matrix A must be square with dimensions n x n."
122+
);
123+
}
124+
if (x0.length != n) {
125+
throw new IllegalArgumentException(
126+
"Initial guess vector x0 must have length n."
127+
);
128+
}
129+
if (lambdaMin >= lambdaMax || lambdaMin <= 0) {
130+
throw new IllegalArgumentException(
131+
"Eigenvalues must satisfy 0 < lambdaMin < lambdaMax."
132+
);
133+
}
134+
}
135+
136+
/**
137+
* Computes y = Ax
138+
*/
139+
private static double[] matrixVectorMultiply(double[][] A, double[] x) {
140+
int n = A.length;
141+
double[] y = new double[n];
142+
for (int i = 0; i < n; i++) {
143+
double sum = 0.0;
144+
for (int j = 0; j < n; j++) {
145+
sum += A[i][j] * x[j];
146+
}
147+
y[i] = sum;
148+
}
149+
return y;
150+
}
151+
152+
/**
153+
* Computes c = a + b
154+
*/
155+
private static double[] vectorAdd(double[] a, double[] b) {
156+
int n = a.length;
157+
double[] c = new double[n];
158+
for (int i = 0; i < n; i++) {
159+
c[i] = a[i] + b[i];
160+
}
161+
return c;
162+
}
163+
164+
/**
165+
* Computes c = a - b
166+
*/
167+
private static double[] vectorSubtract(double[] a, double[] b) {
168+
int n = a.length;
169+
double[] c = new double[n];
170+
for (int i = 0; i < n; i++) {
171+
c[i] = a[i] - b[i];
172+
}
173+
return c;
174+
}
175+
176+
/**
177+
* Computes c = a * scalar
178+
*/
179+
private static double[] vectorScale(double[] a, double scalar) {
180+
int n = a.length;
181+
double[] c = new double[n];
182+
for (int i = 0; i < n; i++) {
183+
c[i] = a[i] * scalar;
184+
}
185+
return c;
186+
}
187+
188+
/**
189+
* Computes the L2 norm (Euclidean norm) of a vector
190+
*/
191+
private static double vectorNorm(double[] a) {
192+
double sumOfSquares = 0.0;
193+
for (double val : a) {
194+
sumOfSquares += val * val;
195+
}
196+
return Math.sqrt(sumOfSquares);
197+
}
198+
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package com.thealgorithms.maths;
2+
3+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import org.junit.jupiter.api.Test;
7+
8+
class ChebyshevIterationTest {
9+
10+
// --- Constants for testSolveSimple2x2System ---
11+
private static final double M1_A11 = 4.0;
12+
private static final double M1_A12 = 1.0;
13+
private static final double M1_A21 = 1.0;
14+
private static final double M1_A22 = 3.0;
15+
private static final double[][] M1_A = { { M1_A11, M1_A12 }, { M1_A21, M1_A22 } };
16+
17+
private static final double M1_B1 = 1.0;
18+
private static final double M1_B2 = 2.0;
19+
private static final double[] M1_B = { M1_B1, M1_B2 };
20+
private static final double[] M1_X0 = { 0.0, 0.0 };
21+
22+
// Eigenvalues are (7 +/- sqrt(5)) / 2
23+
private static final double M1_LAMBDA_MIN = (7.0 - Math.sqrt(5.0)) / 2.0;
24+
private static final double M1_LAMBDA_MAX = (7.0 + Math.sqrt(5.0)) / 2.0;
25+
private static final double M1_EXPECTED_X1 = 1.0 / 11.0;
26+
private static final double M1_EXPECTED_X2 = 7.0 / 11.0;
27+
private static final double[] M1_EXPECTED = { M1_EXPECTED_X1, M1_EXPECTED_X2 };
28+
29+
// --- Constants for testSolve3x3System ---
30+
private static final double[][] M2_A = { { 5.0, 0.0, 0.0 }, { 0.0, 2.0, 0.0 }, { 0.0, 0.0, 8.0 } };
31+
private static final double[] M2_B = { 10.0, -4.0, 24.0 };
32+
private static final double[] M2_X0 = { 0.0, 0.0, 0.0 };
33+
private static final double[] M2_EXPECTED = { 2.0, -2.0, 3.0 };
34+
private static final double M2_LAMBDA_MIN = 2.0;
35+
private static final double M2_LAMBDA_MAX = 8.0;
36+
37+
// --- Constants for testAlreadyConverged ---
38+
private static final double M3_LAMBDA_MIN = 2.38;
39+
private static final double M3_LAMBDA_MAX = 4.62;
40+
41+
// --- General Constants ---
42+
private static final int MAX_ITERATIONS = 100;
43+
private static final double TOLERANCE = 1e-10;
44+
private static final double ASSERT_TOLERANCE = 1e-9;
45+
private static final int TEST_ITERATIONS = 10;
46+
private static final double TEST_TOLERANCE = 1e-5;
47+
48+
@Test
49+
void testSolveSimple2x2System() {
50+
double[] solution = ChebyshevIteration.solve(
51+
M1_A,
52+
M1_B,
53+
M1_X0,
54+
M1_LAMBDA_MIN,
55+
M1_LAMBDA_MAX,
56+
MAX_ITERATIONS,
57+
TOLERANCE
58+
);
59+
assertArrayEquals(M1_EXPECTED, solution, ASSERT_TOLERANCE);
60+
}
61+
62+
@Test
63+
void testSolve3x3System() {
64+
double[] solution = ChebyshevIteration.solve(
65+
M2_A,
66+
M2_B,
67+
M2_X0,
68+
M2_LAMBDA_MIN,
69+
M2_LAMBDA_MAX,
70+
MAX_ITERATIONS,
71+
TOLERANCE
72+
);
73+
assertArrayEquals(M2_EXPECTED, solution, ASSERT_TOLERANCE);
74+
}
75+
76+
@Test
77+
void testAlreadyConverged() {
78+
// Test case where the initial guess is already the solution
79+
double[] solution = ChebyshevIteration.solve(
80+
M1_A,
81+
M1_B,
82+
M1_EXPECTED, // Use expected solution as initial guess
83+
M3_LAMBDA_MIN, // Use approximate eigenvalues
84+
M3_LAMBDA_MAX,
85+
MAX_ITERATIONS,
86+
TOLERANCE
87+
);
88+
assertArrayEquals(M1_EXPECTED, solution, ASSERT_TOLERANCE);
89+
}
90+
91+
@Test
92+
void testInvalidEigenvalues() {
93+
double[][] A = { { 1.0, 0.0 }, { 0.0, 1.0 } };
94+
double[] b = { 1.0, 1.0 };
95+
double[] x0 = { 0.0, 0.0 };
96+
97+
// lambdaMin >= lambdaMax
98+
assertThrows(
99+
IllegalArgumentException.class,
100+
() ->
101+
ChebyshevIteration.solve(
102+
A,
103+
b,
104+
x0,
105+
2.0,
106+
1.0,
107+
TEST_ITERATIONS,
108+
TEST_TOLERANCE
109+
)
110+
);
111+
// lambdaMin <= 0
112+
assertThrows(
113+
IllegalArgumentException.class,
114+
() ->
115+
ChebyshevIteration.solve(
116+
A,
117+
b,
118+
x0,
119+
0.0,
120+
2.0,
121+
TEST_ITERATIONS,
122+
TEST_TOLERANCE
123+
)
124+
);
125+
}
126+
127+
@Test
128+
void testMismatchedDimensions() {
129+
double[][] A = { { 1.0, 0.0 }, { 0.0, 1.0 } };
130+
double[] b = { 1.0, 1.0, 1.0 }; // b.length = 3
131+
double[] x0 = { 0.0, 0.0 }; // x0.length = 2
132+
133+
assertThrows(
134+
IllegalArgumentException.class,
135+
() ->
136+
ChebyshevIteration.solve(
137+
A,
138+
b,
139+
x0,
140+
0.5,
141+
1.5,
142+
TEST_ITERATIONS,
143+
TEST_TOLERANCE
144+
)
145+
);
146+
}
147+
148+
@Test
149+
void testNonSquareMatrix() {
150+
double[][] A = { { 1.0, 0.0, 0.0 }, { 0.0, 1.0, 0.0 } }; // 2x3 matrix
151+
double[] b = { 1.0, 1.0 };
152+
double[] x0 = { 0.0, 0.0 };
153+
154+
assertThrows(
155+
IllegalArgumentException.class,
156+
() ->
157+
ChebyshevIteration.solve(
158+
A,
159+
b,
160+
x0,
161+
0.5,
162+
1.5,
163+
TEST_ITERATIONS,
164+
TEST_TOLERANCE
165+
)
166+
);
167+
}
168+
}

0 commit comments

Comments
 (0)