66 * Implements the Chebyshev Iteration method for solving a system of linear
77 * equations (Ax = b).
88 *
9- * @author Mitrajit ghorui (github: keyKyrios)
10- * @see <a href="https://en.wikipedia.org/wiki/Chebyshev_iteration">Wikipedia
11- * Page</a>
9+ * This iterative method requires:
10+ * - Matrix A to be symmetric positive definite (SPD)
11+ * - Knowledge of minimum (lambdaMin) and maximum (lambdaMax) eigenvalues
1212 *
13- * This is an iterative method that requires the matrix A to be
14- * symmetric positive definite (SPD).
15- * It also requires knowledge of the minimum (lambdaMin) and maximum
16- * (lambdaMax)
17- * eigenvalues of the matrix A.
13+ * Reference: https://en.wikipedia.org/wiki/Chebyshev_iteration
1814 *
19- * The algorithm converges faster than simpler methods like Jacobi or
20- * Gauss-Seidel
21- * by using Chebyshev polynomials to optimize the update steps.
15+ * Author: Mitrajit Ghorui (github: keyKyrios)
2216 */
2317public final class ChebyshevIteration {
2418
25- /**
26- * Private constructor to prevent instantiation of this utility class.
27- */
2819 private ChebyshevIteration () {
2920 }
3021
3122 /**
32- * Solves the linear system Ax = b using Chebyshev Iteration.
23+ * Solves Ax = b using Chebyshev Iteration.
3324 *
34- * @param a A symmetric positive definite matrix.
35- * @param b The vector 'b' in the equation Ax = b.
36- * @param x0 An initial guess vector for 'x'.
37- * @param lambdaMin The smallest eigenvalue of matrix A.
38- * @param lambdaMax The largest eigenvalue of matrix A.
39- * @param maxIterations The maximum number of iterations to perform.
40- * @param tolerance The desired tolerance for convergence (e.g., 1e-10).
41- * @return The solution vector 'x'.
42- * @throws IllegalArgumentException if matrix/vector dimensions don't
43- * match,
44- * or if max/min eigenvalues are invalid.
25+ * @param A SPD matrix
26+ * @param b vector b
27+ * @param x0 initial guess
28+ * @param lambdaMin minimum eigenvalue
29+ * @param lambdaMax maximum eigenvalue
30+ * @param maxIterations maximum iterations
31+ * @param tolerance convergence tolerance
32+ * @return solution vector x
4533 */
46- public static double [] solve (double [][] a , double [] b , double [] x0 , double lambdaMin , double lambdaMax , int maxIterations , double tolerance ) {
47- validateInputs (a , b , x0 , lambdaMin , lambdaMax );
34+ public static double [] solve (double [][] A , double [] b , double [] x0 ,
35+ double lambdaMin , double lambdaMax ,
36+ int maxIterations , double tolerance ) {
37+ validateInputs (A , b , x0 , lambdaMin , lambdaMax );
4838
4939 int n = b .length ;
5040 double [] x = Arrays .copyOf (x0 , n );
51- double [] r = vectorSubtract (b , matrixVectorMultiply (a , x )); // Use `a`
41+ double [] r = vectorSubtract (b , matrixVectorMultiply (A , x ));
5242 double [] p = new double [n ];
5343 double alpha = 0.0 ;
5444 double beta = 0.0 ;
@@ -66,105 +56,71 @@ public static double[] solve(double[][] a, double[] b, double[] x0, double lambd
6656 p = Arrays .copyOf (r , n );
6757 } else {
6858 double alphaPrev = alpha ;
69-
70- beta = (c * alphaPrev / 2.0 ) * (c * alphaPrev / 2.0 );
59+ beta = Math .pow (c * alphaPrev / 2.0 , 2 );
7160 alpha = 1.0 / (d - beta / alphaPrev );
72-
73- double betaOverAlphaPrev = beta / alphaPrev ;
74- double [] rScaled = vectorScale (p , betaOverAlphaPrev );
75- p = vectorAdd (r , rScaled );
61+ p = vectorAdd (r , vectorScale (p , beta / alphaPrev ));
7662 }
7763
78- double [] pScaled = vectorScale (p , alpha );
79- x = vectorAdd (x , pScaled );
80-
81- // Re-calculate residual to avoid accumulating floating-point errors
82- r = vectorSubtract (b , matrixVectorMultiply (a , x )); // Use `a`
64+ x = vectorAdd (x , vectorScale (p , alpha ));
65+ r = vectorSubtract (b , matrixVectorMultiply (A , x ));
8366
8467 if (vectorNorm (r ) < tolerance ) {
85- break ; // Converged
68+ break ;
8669 }
8770 }
71+
8872 return x ;
8973 }
9074
91- // --- Helper Methods for Linear Algebra ---
92- private static void validateInputs ( double [][] a , double [] b , double [] x0 , double lambdaMin , double lambdaMax ) {
75+ private static void validateInputs ( double [][] A , double [] b , double [] x0 ,
76+ double lambdaMin , double lambdaMax ) {
9377 int n = b .length ;
94- if (n == 0 ) {
95- throw new IllegalArgumentException ("Vectors cannot be empty." );
96- }
97- if (a .length != n || a [0 ].length != n ) { // Use `a`
98- throw new IllegalArgumentException ("Matrix A must be square with dimensions n x n." );
99- }
100- if (x0 .length != n ) {
78+ if (n == 0 ) throw new IllegalArgumentException ("Vectors cannot be empty." );
79+ if (A .length != n || A [0 ].length != n )
80+ throw new IllegalArgumentException ("Matrix A must be square (n x n)." );
81+ if (x0 .length != n )
10182 throw new IllegalArgumentException ("Initial guess vector x0 must have length n." );
102- }
103- if (lambdaMin >= lambdaMax || lambdaMin <= 0 ) {
83+ if (lambdaMin >= lambdaMax || lambdaMin <= 0 )
10484 throw new IllegalArgumentException ("Eigenvalues must satisfy 0 < lambdaMin < lambdaMax." );
105- }
10685 }
10786
108- /**
109- * Computes y = Ax
110- */
111- private static double [] matrixVectorMultiply (double [][] a , double [] x ) {
112- int n = a .length ; // Use `a`
87+ private static double [] matrixVectorMultiply (double [][] A , double [] x ) {
88+ int n = A .length ;
11389 double [] y = new double [n ];
11490 for (int i = 0 ; i < n ; i ++) {
11591 double sum = 0.0 ;
11692 for (int j = 0 ; j < n ; j ++) {
117- sum += a [i ][j ] * x [j ]; // Use `a`
93+ sum += A [i ][j ] * x [j ];
11894 }
11995 y [i ] = sum ;
12096 }
12197 return y ;
12298 }
12399
124- /**
125- * Computes c = a + b
126- */
127100 private static double [] vectorAdd (double [] a , double [] b ) {
128101 int n = a .length ;
129102 double [] c = new double [n ];
130- for (int i = 0 ; i < n ; i ++) {
131- c [i ] = a [i ] + b [i ];
132- }
103+ for (int i = 0 ; i < n ; i ++) c [i ] = a [i ] + b [i ];
133104 return c ;
134105 }
135106
136- /**
137- * Computes c = a - b
138- */
139107 private static double [] vectorSubtract (double [] a , double [] b ) {
140108 int n = a .length ;
141109 double [] c = new double [n ];
142- for (int i = 0 ; i < n ; i ++) {
143- c [i ] = a [i ] - b [i ];
144- }
110+ for (int i = 0 ; i < n ; i ++) c [i ] = a [i ] - b [i ];
145111 return c ;
146112 }
147113
148- /**
149- * Computes c = a * scalar
150- */
151114 private static double [] vectorScale (double [] a , double scalar ) {
152115 int n = a .length ;
153116 double [] c = new double [n ];
154- for (int i = 0 ; i < n ; i ++) {
155- c [i ] = a [i ] * scalar ;
156- }
117+ for (int i = 0 ; i < n ; i ++) c [i ] = a [i ] * scalar ;
157118 return c ;
158119 }
159120
160- /**
161- * Computes the L2 norm (Euclidean norm) of a vector
162- */
163121 private static double vectorNorm (double [] a ) {
164- double sumOfSquares = 0.0 ;
165- for (double val : a ) {
166- sumOfSquares += val * val ;
167- }
168- return Math .sqrt (sumOfSquares );
122+ double sum = 0.0 ;
123+ for (double val : a ) sum += val * val ;
124+ return Math .sqrt (sum );
169125 }
170126}
0 commit comments