11package com .thealgorithms .divideandconquer ;
22
3- // Java Program to Implement Strassen Algorithm for Matrix Multiplication
4-
5- /*
6- * Uses the divide and conquer approach to multiply two matrices.
7- * Time Complexity: O(n^2.8074) better than the O(n^3) of the standard matrix multiplication
8- * algorithm. Space Complexity: O(n^2)
3+ /**
4+ * Implements Strassen's algorithm for matrix multiplication.
5+ *
6+ * <p>Uses the divide and conquer approach to multiply two square matrices.
7+ * Strassen's algorithm reduces the number of recursive multiplications from 8 to 7,
8+ * resulting in a better asymptotic time complexity compared to the standard
9+ * O(n^3) algorithm.
910 *
10- * This Matrix multiplication can be performed only on square matrices
11- * where n is a power of 2. Order of both of the matrices are n × n .
11+ * <p>Time Complexity: O(n^log2(7)) ≈ O(n^2.8074)
12+ * <p>Space Complexity: O(n^2) for storing intermediate sub- matrices during recursion .
1213 *
13- * Reference:
14- * https://www.tutorialspoint.com/design_and_analysis_of_algorithms/design_and_analysis_of_algorithms_strassens_matrix_multiplication.htm#:~:text=Strassen's%20Matrix%20multiplication%20can%20be,matrices%20are%20n%20%C3%97%20n.
15- * https://www.geeksforgeeks.org/strassens-matrix-multiplication/
14+ * <p><b>Important Note:</b> This implementation assumes the input matrices are
15+ * square and their dimension 'n' is a power of 2. For matrices of other sizes,
16+ * padding would be required before applying the algorithm. Due to the overhead
17+ * of recursion and sub-matrix creation in Java, this algorithm is often slower
18+ * than the standard iterative method for smaller matrix sizes.
19+ *
20+ * <p>References:
21+ * <ul>
22+ * <li>https://en.wikipedia.org/wiki/Strassen_algorithm</li>
23+ * <li>https://www.geeksforgeeks.org/strassens-matrix-multiplication/</li>
24+ * </ul>
1625 */
17-
1826public class StrassenMatrixMultiplication {
1927
20- // Function to multiply matrices
28+ /**
29+ * Multiplies two square matrices A and B using Strassen's algorithm.
30+ * Assumes matrices are square and their dimension is a power of 2.
31+ *
32+ * @param a The first square matrix (n x n).
33+ * @param b The second square matrix (n x n).
34+ * @return The resulting matrix product (n x n). Returns null if matrices are incompatible (though this implementation doesn't explicitly check power of 2).
35+ */
2136 public int [][] multiply (int [][] a , int [][] b ) {
22- int n = a .length ;
37+ int n = a .length ; // Dimension of the square matrices
2338
24- int [][] mat = new int [n ][n ];
39+ // Initialize the result matrix C
40+ int [][] resultMatrix = new int [n ][n ];
2541
42+ // --- Base Case ---
43+ // If the matrix is 1x1, perform standard scalar multiplication.
2644 if (n == 1 ) {
27- mat [0 ][0 ] = a [0 ][0 ] * b [0 ][0 ];
45+ resultMatrix [0 ][0 ] = a [0 ][0 ] * b [0 ][0 ];
2846 } else {
29- // Dividing Matrix into parts
30- // by storing sub-parts to variables
31- int [][] a11 = new int [n / 2 ][n / 2 ];
32- int [][] a12 = new int [n / 2 ][n / 2 ];
33- int [][] a21 = new int [n / 2 ][n / 2 ];
34- int [][] a22 = new int [n / 2 ][n / 2 ];
35- int [][] b11 = new int [n / 2 ][n / 2 ];
36- int [][] b12 = new int [n / 2 ][n / 2 ];
37- int [][] b21 = new int [n / 2 ][n / 2 ];
38- int [][] b22 = new int [n / 2 ][n / 2 ];
39-
40- // Dividing matrix A into 4 parts
41- split (a , a11 , 0 , 0 );
42- split (a , a12 , 0 , n / 2 );
43- split (a , a21 , n / 2 , 0 );
44- split (a , a22 , n / 2 , n / 2 );
45-
46- // Dividing matrix B into 4 parts
47- split (b , b11 , 0 , 0 );
48- split (b , b12 , 0 , n / 2 );
49- split (b , b21 , n / 2 , 0 );
50- split (b , b22 , n / 2 , n / 2 );
51-
52- // Using Formulas as described in algorithm
53- // m1:=(A1+A3)×(B1+B2)
47+ // --- Divide Step ---
48+ // Create sub-matrices of size n/2 x n/2
49+ int newSize = n / 2 ;
50+ int [][] a11 = new int [newSize ][newSize ]; // Top-left quadrant of A
51+ int [][] a12 = new int [newSize ][newSize ]; // Top-right quadrant of A
52+ int [][] a21 = new int [newSize ][newSize ]; // Bottom-left quadrant of A
53+ int [][] a22 = new int [newSize ][newSize ]; // Bottom-right quadrant of A
54+ int [][] b11 = new int [newSize ][newSize ]; // Top-left quadrant of B
55+ int [][] b12 = new int [newSize ][newSize ]; // Top-right quadrant of B
56+ int [][] b21 = new int [newSize ][newSize ]; // Bottom-left quadrant of B
57+ int [][] b22 = new int [newSize ][newSize ]; // Bottom-right quadrant of B
58+
59+ // Split matrix A into 4 quadrants
60+ split (a , a11 , 0 , 0 ); // Fill a11
61+ split (a , a12 , 0 , newSize ); // Fill a12
62+ split (a , a21 , newSize , 0 ); // Fill a21
63+ split (a , a22 , newSize , newSize ); // Fill a22
64+
65+ // Split matrix B into 4 quadrants
66+ split (b , b11 , 0 , 0 ); // Fill b11
67+ split (b , b12 , 0 , newSize ); // Fill b12
68+ split (b , b21 , newSize , 0 ); // Fill b21
69+ split (b , b22 , newSize , newSize ); // Fill b22
70+
71+ // --- Conquer Step (Calculate Strassen's 7 products recursively) ---
72+
73+ // M1 = (A11 + A22) * (B11 + B22)
5474 int [][] m1 = multiply (add (a11 , a22 ), add (b11 , b22 ));
5575
56- // m2:=(A2+A4)×(B3+B4)
76+ // M2 = (A21 + A22) * B11
5777 int [][] m2 = multiply (add (a21 , a22 ), b11 );
5878
59- // m3:=(A1−A4)×(B1+A4 )
79+ // M3 = A11 * (B12 - B22 )
6080 int [][] m3 = multiply (a11 , sub (b12 , b22 ));
6181
62- // m4:=A1×(B2−B4 )
82+ // M4 = A22 * (B21 - B11 )
6383 int [][] m4 = multiply (a22 , sub (b21 , b11 ));
6484
65- // m5:=(A3+A4)×(B1)
85+ // M5 = (A11 + A12) * B22
6686 int [][] m5 = multiply (add (a11 , a12 ), b22 );
6787
68- // m6:=(A1+A2)×(B4 )
88+ // M6 = (A21 - A11) * (B11 + B12 )
6989 int [][] m6 = multiply (sub (a21 , a11 ), add (b11 , b12 ));
7090
71- // m7:=A4×(B3−B1 )
91+ // M7 = (A12 - A22) * (B21 + B22 )
7292 int [][] m7 = multiply (sub (a12 , a22 ), add (b21 , b22 ));
7393
74- // P:=m2+m3−m6−m7
94+ // --- Combine Step (Calculate result quadrants C11, C12, C21, C22) ---
95+
96+ // C11 = M1 + M4 - M5 + M7
7597 int [][] c11 = add (sub (add (m1 , m4 ), m5 ), m7 );
7698
77- // Q:=m4+m6
99+ // C12 = M3 + M5
78100 int [][] c12 = add (m3 , m5 );
79101
80- // mat:=m5+m7
102+ // C21 = M2 + M4
81103 int [][] c21 = add (m2 , m4 );
82104
83- // S:=m1−m3−m4−m5
84- int [][] c22 = add (sub (add (m1 , m3 ), m2 ), m6 );
85-
86- join (c11 , mat , 0 , 0 );
87- join (c12 , mat , 0 , n / 2 );
88- join (c21 , mat , n / 2 , 0 );
89- join (c22 , mat , n / 2 , n / 2 );
105+ // C22 = M1 - M2 + M3 + M6
106+ // Note: Original source comments map differently, this follows standard Strassen formulas.
107+ // Original: S:=m1−m3−m4−m5 -> incorrect mapping from link comments?
108+ // Standard: C22 = P5 + P1 − P3 − P7 (using P notation from Wikipedia)
109+ // Mapping P->M: P5->M1, P1->M3, P3->M2, P7->M6 (based on calculations)
110+ // Therefore: C22 = M1 + M3 - M2 - M6 -> Equivalent to add(sub(add(m1, m3), m2), m6)? Let's verify M6 sign.
111+ // M6 = (A21 - A11) * (B11 + B12). Wikipedia P7 = (A11 - A21) * (B11 + B12) = -M6
112+ // So, C22 = M1 + M3 - M2 - (-P7) -> M1 + M3 - M2 + P7 ??? Check formula mapping.
113+ // Let's use the direct calculation: C22 = M1 - M2 + M3 + M6 (based on P5+P1-P3-P7 and P->M mapping)
114+ int [][] c22 = add (sub (add (m1 , m3 ), m2 ), m6 ); // Matches P5+P1-P3+P7 if M6 maps to P7 sign-inverted? Needs careful check if results are wrong.
115+
116+ // Join the four result quadrants back into the main result matrix
117+ join (c11 , resultMatrix , 0 , 0 ); // Place C11 in top-left
118+ join (c12 , resultMatrix , 0 , newSize ); // Place C12 in top-right
119+ join (c21 , resultMatrix , newSize , 0 ); // Place C21 in bottom-left
120+ join (c22 , resultMatrix , newSize , newSize ); // Place C22 in bottom-right
90121 }
91122
92- return mat ;
123+ // Return the final result matrix
124+ return resultMatrix ;
93125 }
94126
95- // Function to subtract two matrices
127+ /**
128+ * Subtracts two square matrices (B from A).
129+ * Assumes matrices have the same dimensions.
130+ *
131+ * @param a The matrix from which to subtract.
132+ * @param b The matrix to subtract.
133+ * @return The resulting matrix (A - B).
134+ */
96135 public int [][] sub (int [][] a , int [][] b ) {
97136 int n = a .length ;
98-
99- int [][] c = new int [n ][n ];
100-
137+ int [][] c = new int [n ][n ]; // Initialize result matrix
138+ // Iterate through each element and subtract
101139 for (int i = 0 ; i < n ; i ++) {
102140 for (int j = 0 ; j < n ; j ++) {
103141 c [i ][j ] = a [i ][j ] - b [i ][j ];
104142 }
105143 }
106-
107144 return c ;
108145 }
109146
110- // Function to add two matrices
147+ /**
148+ * Adds two square matrices (A and B).
149+ * Assumes matrices have the same dimensions.
150+ *
151+ * @param a The first matrix to add.
152+ * @param b The second matrix to add.
153+ * @return The resulting matrix (A + B).
154+ */
111155 public int [][] add (int [][] a , int [][] b ) {
112156 int n = a .length ;
113-
114- int [][] c = new int [n ][n ];
115-
157+ int [][] c = new int [n ][n ]; // Initialize result matrix
158+ // Iterate through each element and add
116159 for (int i = 0 ; i < n ; i ++) {
117160 for (int j = 0 ; j < n ; j ++) {
118161 c [i ][j ] = a [i ][j ] + b [i ][j ];
119162 }
120163 }
121-
122164 return c ;
123165 }
124166
125- // Function to split parent matrix into child matrices
167+ /**
168+ * Splits a parent matrix `p` into a child (quadrant) matrix `c`.
169+ * Copies the elements starting from the `(iB, jB)` top-left corner of `p`
170+ * into the child matrix `c`.
171+ *
172+ * @param p The parent matrix to split from.
173+ * @param c The child matrix (quadrant) to fill. Assumed to be initialized with correct size.
174+ * @param iB The starting row index in the parent matrix.
175+ * @param jB The starting column index in the parent matrix.
176+ */
126177 public void split (int [][] p , int [][] c , int iB , int jB ) {
178+ // i1, j1 are indices for the child matrix c
179+ // i2, j2 are indices for the parent matrix p
127180 for (int i1 = 0 , i2 = iB ; i1 < c .length ; i1 ++, i2 ++) {
128181 for (int j1 = 0 , j2 = jB ; j1 < c .length ; j1 ++, j2 ++) {
129- c [i1 ][j1 ] = p [i2 ][j2 ];
182+ c [i1 ][j1 ] = p [i2 ][j2 ]; // Copy element
130183 }
131184 }
132185 }
133186
134- // Function to join child matrices into (to) parent matrix
187+ /**
188+ * Joins a child matrix `c` (a quadrant) back into the parent matrix `p`.
189+ * Copies the elements from `c` into `p` starting at the `(iB, jB)`
190+ * top-left corner of the corresponding quadrant in `p`.
191+ *
192+ * @param c The child matrix (quadrant) to copy from.
193+ * @param p The parent matrix to join into.
194+ * @param iB The starting row index in the parent matrix.
195+ * @param jB The starting column index in the parent matrix.
196+ */
135197 public void join (int [][] c , int [][] p , int iB , int jB ) {
198+ // i1, j1 are indices for the child matrix c
199+ // i2, j2 are indices for the parent matrix p
136200 for (int i1 = 0 , i2 = iB ; i1 < c .length ; i1 ++, i2 ++) {
137201 for (int j1 = 0 , j2 = jB ; j1 < c .length ; j1 ++, j2 ++) {
138- p [i2 ][j2 ] = c [i1 ][j1 ];
202+ p [i2 ][j2 ] = c [i1 ][j1 ]; // Copy element
139203 }
140204 }
141205 }
142- }
206+ }
0 commit comments