1+ package com .thealgorithms .maths ;
2+
3+ import java .util .Arrays ;
4+
5+ /**
6+ * Implements the Chebyshev Iteration method for solving a system of linear
7+ * equations (Ax = b).
8+ *
9+ * @author Mitrajit ghorui (github: keyKyrios)
10+ * @see <a href="https://en.wikipedia.org/wiki/Chebyshev_iteration">Wikipedia
11+ * Page</a>
12+ *
13+ * This is an iterative method that requires the matrix A to be
14+ * symmetric positive definite (SPD).
15+ * It also requires knowledge of the minimum (lambdaMin) and maximum
16+ * (lambdaMax)
17+ * eigenvalues of the matrix A.
18+ *
19+ * The algorithm converges faster than simpler methods like Jacobi or
20+ * Gauss-Seidel
21+ * by using Chebyshev polynomials to optimize the update steps.
22+ */
23+ public final class ChebyshevIteration {
24+
25+ /**
26+ * Private constructor to prevent instantiation of this utility class.
27+ */
28+ private ChebyshevIteration () {
29+ }
30+
31+ /**
32+ * Solves the linear system Ax = b using Chebyshev Iteration.
33+ *
34+ * @param A A symmetric positive definite matrix.
35+ * @param b The vector 'b' in the equation Ax = b.
36+ * @param x0 An initial guess vector for 'x'.
37+ * @param lambdaMin The smallest eigenvalue of matrix A.
38+ * @param lambdaMax The largest eigenvalue of matrix A.
39+ * @param maxIterations The maximum number of iterations to perform.
40+ * @param tolerance The desired tolerance for convergence (e.g., 1e-10).
41+ * @return The solution vector 'x'.
42+ * @throws IllegalArgumentException if matrix/vector dimensions don't
43+ * match,
44+ * or if max/min eigenvalues are invalid.
45+ */
46+ public static double [] solve (
47+ double [][] A ,
48+ double [] b ,
49+ double [] x0 ,
50+ double lambdaMin ,
51+ double lambdaMax ,
52+ int maxIterations ,
53+ double tolerance
54+ ) {
55+ validateInputs (A , b , x0 , lambdaMin , lambdaMax );
56+
57+ int n = b .length ;
58+ double [] x = Arrays .copyOf (x0 , n );
59+ double [] r = vectorSubtract (b , matrixVectorMultiply (A , x ));
60+ double [] p = new double [n ];
61+ double alpha = 0.0 ;
62+ double beta = 0.0 ;
63+ double c = (lambdaMax - lambdaMin ) / 2.0 ;
64+ double d = (lambdaMax + lambdaMin ) / 2.0 ;
65+
66+ double initialResidualNorm = vectorNorm (r );
67+ if (initialResidualNorm < tolerance ) {
68+ return x ; // Already converged
69+ }
70+
71+ for (int k = 1 ; k <= maxIterations ; k ++) {
72+ if (k == 1 ) {
73+ alpha = 1.0 / d ;
74+ p = Arrays .copyOf (r , n );
75+ } else {
76+
77+ // 1. Save the previous alpha
78+ double alphaPrev = alpha ;
79+
80+ // 2. Calculate new beta and alpha using alphaPrev
81+ beta = (c * alphaPrev / 2.0 ) * (c * alphaPrev / 2.0 );
82+ alpha = 1.0 / (d - beta / alphaPrev );
83+
84+ // 3. Use alphaPrev in the p update
85+ double betaOverAlphaPrev = beta / alphaPrev ;
86+ double [] rScaled = vectorScale (p , betaOverAlphaPrev );
87+ p = vectorAdd (r , rScaled );
88+
89+ }
90+
91+ double [] pScaled = vectorScale (p , alpha );
92+ x = vectorAdd (x , pScaled );
93+
94+ // Re-calculate residual to avoid accumulating floating-point errors
95+ // Note: Some variants calculate r = r - alpha * A * p for
96+ // efficiency,
97+ // but this direct calculation is more stable against drift.
98+ r = vectorSubtract (b , matrixVectorMultiply (A , x ));
99+
100+ if (vectorNorm (r ) < tolerance ) {
101+ break ; // Converged
102+ }
103+ }
104+ return x ;
105+ }
106+
107+ // --- Helper Methods for Linear Algebra ---
108+ private static void validateInputs (
109+ double [][] A ,
110+ double [] b ,
111+ double [] x0 ,
112+ double lambdaMin ,
113+ double lambdaMax
114+ ) {
115+ int n = b .length ;
116+ if (n == 0 ) {
117+ throw new IllegalArgumentException ("Vectors cannot be empty." );
118+ }
119+ if (A .length != n || A [0 ].length != n ) {
120+ throw new IllegalArgumentException (
121+ "Matrix A must be square with dimensions n x n."
122+ );
123+ }
124+ if (x0 .length != n ) {
125+ throw new IllegalArgumentException (
126+ "Initial guess vector x0 must have length n."
127+ );
128+ }
129+ if (lambdaMin >= lambdaMax || lambdaMin <= 0 ) {
130+ throw new IllegalArgumentException (
131+ "Eigenvalues must satisfy 0 < lambdaMin < lambdaMax."
132+ );
133+ }
134+ }
135+
136+ /**
137+ * Computes y = Ax
138+ */
139+ private static double [] matrixVectorMultiply (double [][] A , double [] x ) {
140+ int n = A .length ;
141+ double [] y = new double [n ];
142+ for (int i = 0 ; i < n ; i ++) {
143+ double sum = 0.0 ;
144+ for (int j = 0 ; j < n ; j ++) {
145+ sum += A [i ][j ] * x [j ];
146+ }
147+ y [i ] = sum ;
148+ }
149+ return y ;
150+ }
151+
152+ /**
153+ * Computes c = a + b
154+ */
155+ private static double [] vectorAdd (double [] a , double [] b ) {
156+ int n = a .length ;
157+ double [] c = new double [n ];
158+ for (int i = 0 ; i < n ; i ++) {
159+ c [i ] = a [i ] + b [i ];
160+ }
161+ return c ;
162+ }
163+
164+ /**
165+ * Computes c = a - b
166+ */
167+ private static double [] vectorSubtract (double [] a , double [] b ) {
168+ int n = a .length ;
169+ double [] c = new double [n ];
170+ for (int i = 0 ; i < n ; i ++) {
171+ c [i ] = a [i ] - b [i ];
172+ }
173+ return c ;
174+ }
175+
176+ /**
177+ * Computes c = a * scalar
178+ */
179+ private static double [] vectorScale (double [] a , double scalar ) {
180+ int n = a .length ;
181+ double [] c = new double [n ];
182+ for (int i = 0 ; i < n ; i ++) {
183+ c [i ] = a [i ] * scalar ;
184+ }
185+ return c ;
186+ }
187+
188+ /**
189+ * Computes the L2 norm (Euclidean norm) of a vector
190+ */
191+ private static double vectorNorm (double [] a ) {
192+ double sumOfSquares = 0.0 ;
193+ for (double val : a ) {
194+ sumOfSquares += val * val ;
195+ }
196+ return Math .sqrt (sumOfSquares );
197+ }
198+ }
0 commit comments