Skip to content

Commit d4af9fc

Browse files
feat: Add Strassen matrix multiplication algorithm and tests
1 parent e21aee8 commit d4af9fc

File tree

2 files changed

+394
-0
lines changed

2 files changed

+394
-0
lines changed
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
package com.thealgorithms.matrix;
2+
3+
/**
4+
* This class provides a method to perform matrix multiplication using
5+
* Strassen's algorithm.
6+
*
7+
* <p>
8+
* Strassen's algorithm is a divide-and-conquer algorithm that is
9+
* asymptotically faster than the standard O(n^3) matrix multiplication.
10+
* It performs 7 recursive multiplications of sub-matrices of size n/2
11+
* instead of the 8 required by the standard recursive method.
12+
*
13+
* <p>
14+
* For more details:
15+
* https://en.wikipedia.org/wiki/Strassen_algorithm
16+
*
17+
* <p>
18+
* Time Complexity: O(n^log2(7)) ≈ O(n^2.807)
19+
*
20+
* <p>
21+
* Space Complexity: O(n^2) – for storing intermediate and result matrices.
22+
*
23+
* <p>
24+
* Note: Due to the high overhead of recursion and sub-matrix creation in
25+
* Java, this algorithm is often slower than the standard O(n^3)
26+
* {@link MatrixMultiplication} for smaller matrices. A threshold is used
27+
* to switch to the standard algorithm for small matrices.
28+
*
29+
* @author @ITZ-NIHALPATEL
30+
*
31+
*/
32+
public final class StrassenMatrixMultiplication {
33+
34+
/**
35+
* Threshold for matrix size to switch from Strassen's to standard
36+
* multiplication. Tuned by performance testing, 64 is a common value.
37+
*/
38+
private static final int THRESHOLD = 64;
39+
40+
private StrassenMatrixMultiplication() {
41+
}
42+
43+
/**
44+
* Multiplies two matrices using Strassen's algorithm.
45+
*
46+
* @param matrixA the first matrix (must be square, n x n)
47+
* @param matrixB the second matrix (must be square, n x n)
48+
* @return the product of the two matrices
49+
* @throws IllegalArgumentException if matrices are not square, not the
50+
* same size, or cannot be multiplied.
51+
*/
52+
public static double[][] multiply(double[][] matrixA, double[][] matrixB) {
53+
// --- 1. VALIDATION ---
54+
if (matrixA == null || matrixB == null) {
55+
throw new IllegalArgumentException("Input matrices cannot be null");
56+
}
57+
if (matrixA.length == 0 || (matrixA.length > 0 && matrixA[0].length == 0)) {
58+
return new double[0][0]; // Handle empty matrix
59+
}
60+
61+
int n = matrixA.length;
62+
if (n != matrixA[0].length || n != matrixB.length || n != matrixB[0].length) {
63+
throw new IllegalArgumentException(
64+
"Strassen's algorithm requires square matrices of the same dimension (n x n)."
65+
);
66+
}
67+
68+
// --- 2. PADDING ---
69+
// Find the next power of 2
70+
int nextPowerOf2 = Integer.highestOneBit(n);
71+
if (nextPowerOf2 < n) {
72+
nextPowerOf2 <<= 1;
73+
}
74+
75+
// Pad matrices to the next power of 2
76+
double[][] paddedA = pad(matrixA, nextPowerOf2);
77+
double[][] paddedB = pad(matrixB, nextPowerOf2);
78+
79+
// --- 3. RECURSION ---
80+
double[][] paddedResult = multiplyRecursive(paddedA, paddedB);
81+
82+
// --- 4. UNPADDING ---
83+
// Extract the original n x n result from the padded result
84+
return unpad(paddedResult, n);
85+
}
86+
87+
/**
88+
* Recursive helper function for Strassen's algorithm.
89+
* Assumes input matrices are square and their size is a power of 2.
90+
*/
91+
private static double[][] multiplyRecursive(double[][] matrixA, double[][] matrixB) {
92+
int n = matrixA.length;
93+
94+
// --- BASE CASE ---
95+
// If the matrix is small, switch to the standard O(n^3) algorithm
96+
if (n <= THRESHOLD) {
97+
return MatrixMultiplication.multiply(matrixA, matrixB);
98+
}
99+
100+
// --- DIVIDE ---
101+
// Split matrices into four n/2 x n/2 sub-matrices
102+
int newSize = n / 2;
103+
double[][] a11 = split(matrixA, 0, 0, newSize);
104+
double[][] a12 = split(matrixA, 0, newSize, newSize);
105+
double[][] a21 = split(matrixA, newSize, 0, newSize);
106+
double[][] a22 = split(matrixA, newSize, newSize, newSize);
107+
108+
double[][] b11 = split(matrixB, 0, 0, newSize);
109+
double[][] b12 = split(matrixB, 0, newSize, newSize);
110+
double[][] b21 = split(matrixB, newSize, 0, newSize);
111+
double[][] b22 = split(matrixB, newSize, newSize, newSize);
112+
113+
// --- CONQUER (7 Recursive Calls) ---
114+
// P1 = A11 * (B12 - B22)
115+
double[][] p1 = multiplyRecursive(a11, subtract(b12, b22));
116+
// P2 = (A11 + A12) * B22
117+
double[][] p2 = multiplyRecursive(add(a11, a12), b22);
118+
// P3 = (A21 + A22) * B11
119+
double[][] p3 = multiplyRecursive(add(a21, a22), b11);
120+
// P4 = A22 * (B21 - B11)
121+
double[][] p4 = multiplyRecursive(a22, subtract(b21, b11));
122+
// P5 = (A11 + A22) * (B11 + B22)
123+
double[][] p5 = multiplyRecursive(add(a11, a22), add(b11, b22));
124+
// P6 = (A12 - A22) * (B21 + B22)
125+
double[][] p6 = multiplyRecursive(subtract(a12, a22), add(b21, b22));
126+
// P7 = (A11 - A21) * (B11 + B12)
127+
double[][] p7 = multiplyRecursive(subtract(a11, a21), add(b11, b12));
128+
129+
// --- COMBINE (Calculate Result Quadrants) ---
130+
// C11 = P5 + P4 - P2 + P6
131+
double[][] c11 = add(subtract(add(p5, p4), p2), p6);
132+
// C12 = P1 + P2
133+
double[][] c12 = add(p1, p2);
134+
// C21 = P3 + P4
135+
double[][] c21 = add(p3, p4);
136+
// C22 = P5 + P1 - P3 - P7
137+
double[][] c22 = subtract(subtract(add(p5, p1), p3), p7);
138+
139+
// Join the four result quadrants into a single matrix
140+
return join(c11, c12, c21, c22);
141+
}
142+
143+
// --- HELPER METHODS ---
144+
/**
145+
* Adds two matrices.
146+
*/
147+
private static double[][] add(double[][] matrixA, double[][] matrixB) {
148+
int n = matrixA.length;
149+
double[][] result = new double[n][n];
150+
for (int i = 0; i < n; i++) {
151+
for (int j = 0; j < n; j++) {
152+
result[i][j] = matrixA[i][j] + matrixB[i][j];
153+
}
154+
}
155+
return result;
156+
}
157+
158+
/**
159+
* Subtracts matrixB from matrixA.
160+
*/
161+
private static double[][] subtract(double[][] matrixA, double[][] matrixB) {
162+
int n = matrixA.length;
163+
double[][] result = new double[n][n];
164+
for (int i = 0; i < n; i++) {
165+
for (int j = 0; j < n; j++) {
166+
result[i][j] = matrixA[i][j] - matrixB[i][j];
167+
}
168+
}
169+
return result;
170+
}
171+
172+
/**
173+
* Splits a parent matrix into a new sub-matrix.
174+
*/
175+
private static double[][] split(
176+
double[][] matrix,
177+
int rowStart,
178+
int colStart,
179+
int size
180+
) {
181+
double[][] subMatrix = new double[size][size];
182+
for (int i = 0; i < size; i++) {
183+
System.arraycopy(
184+
matrix[i + rowStart],
185+
colStart,
186+
subMatrix[i],
187+
0,
188+
size
189+
);
190+
}
191+
return subMatrix;
192+
}
193+
194+
/**
195+
* Joins four sub-matrices into one larger matrix.
196+
*/
197+
private static double[][] join(
198+
double[][] c11,
199+
double[][] c12,
200+
double[][] c21,
201+
double[][] c22
202+
) {
203+
int n = c11.length;
204+
int newSize = n * 2;
205+
double[][] result = new double[newSize][newSize];
206+
for (int i = 0; i < n; i++) {
207+
// C11
208+
System.arraycopy(c11[i], 0, result[i], 0, n);
209+
// C12
210+
System.arraycopy(c12[i], 0, result[i], n, n);
211+
// C21
212+
System.arraycopy(c21[i], 0, result[i + n], 0, n);
213+
// C22
214+
System.arraycopy(c22[i], 0, result[i + n], n, n);
215+
}
216+
return result;
217+
}
218+
219+
/**
220+
* Pads a matrix with zeros to a new larger size.
221+
*/
222+
private static double[][] pad(double[][] matrix, int size) {
223+
if (matrix.length == size) {
224+
return matrix; // No padding needed
225+
}
226+
int n = matrix.length;
227+
double[][] padded = new double[size][size];
228+
for (int i = 0; i < n; i++) {
229+
System.arraycopy(matrix[i], 0, padded[i], 0, matrix[i].length);
230+
}
231+
return padded;
232+
}
233+
234+
/**
235+
* Unpads a matrix to a new smaller size.
236+
*/
237+
private static double[][] unpad(double[][] matrix, int size) {
238+
if (matrix.length == size) {
239+
return matrix; // No unpadding needed
240+
}
241+
double[][] unpadded = new double[size][size];
242+
for (int i = 0; i < size; i++) {
243+
System.arraycopy(matrix[i], 0, unpadded[i], 0, size);
244+
}
245+
return unpadded;
246+
}
247+
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package com.thealgorithms.matrix;
2+
3+
import static org.junit.jupiter.api.Assertions.*;
4+
5+
import org.junit.jupiter.api.Test;
6+
7+
/**
8+
* Unit tests for the StrassenMatrixMultiplication class.
9+
*/
10+
class StrassenMatrixMultiplicationTest {
11+
12+
// Define some test matrices
13+
private static final double[][] MATRIX_2X2_A = {{1, 2}, {3, 4}};
14+
private static final double[][] MATRIX_2X2_B = {{5, 6}, {7, 8}};
15+
private static final double[][] EXPECTED_2X2_PRODUCT = {{19, 22}, {43, 50}};
16+
17+
private static final double[][] MATRIX_4X4_A = {
18+
{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}};
19+
private static final double[][] MATRIX_4X4_B = {
20+
{5, 8, 1, 2}, {6, 7, 3, 0}, {4, 5, 9, 1}, {2, 6, 10, 14}};
21+
private static final double[][] EXPECTED_4X4_PRODUCT = {
22+
{37, 61, 74, 61}, {105, 165, 166, 129}, {173, 269, 258, 197}, {241, 373, 350, 265}};
23+
24+
private static final double[][] MATRIX_3X3_A = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
25+
private static final double[][] MATRIX_3X3_B = {{9, 8, 7}, {6, 5, 4}, {3, 2, 1}};
26+
private static final double[][] EXPECTED_3X3_PRODUCT = {{30, 24, 18}, {84, 69, 54}, {138, 114, 90}};
27+
28+
private static final double[][] MATRIX_IDENTITY_2X2 = {{1, 0}, {0, 1}};
29+
private static final double[][] MATRIX_ZERO_2X2 = {{0, 0}, {0, 0}};
30+
31+
private static final double[][] MATRIX_NON_SQUARE = {{1, 2, 3}, {4, 5, 6}};
32+
33+
// Tolerance for floating-point comparisons
34+
private static final double DELTA = 1e-9;
35+
36+
/**
37+
* Helper method to compare two matrices with tolerance.
38+
*/
39+
private void assertMatrixEquals(double[][] expected, double[][] actual) {
40+
assertEquals(expected.length, actual.length, "Number of rows differ");
41+
for (int i = 0; i < expected.length; i++) {
42+
assertArrayEquals(
43+
expected[i],
44+
actual[i],
45+
DELTA,
46+
"Row " + i + " differs"
47+
);
48+
}
49+
}
50+
51+
@Test
52+
void testMultiply2x2() {
53+
double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_2X2_B);
54+
assertMatrixEquals(EXPECTED_2X2_PRODUCT, result);
55+
}
56+
57+
@Test
58+
void testMultiply4x4() {
59+
double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_4X4_A, MATRIX_4X4_B);
60+
assertMatrixEquals(EXPECTED_4X4_PRODUCT, result);
61+
}
62+
63+
@Test
64+
void testMultiply3x3RequiresPadding() {
65+
// Strassen requires padding for non-power-of-2 dimensions
66+
double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_3X3_A, MATRIX_3X3_B);
67+
assertMatrixEquals(EXPECTED_3X3_PRODUCT, result);
68+
}
69+
70+
@Test
71+
void testMultiplyByIdentity() {
72+
double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_IDENTITY_2X2);
73+
assertMatrixEquals(MATRIX_2X2_A, result);
74+
75+
double[][] result2 = StrassenMatrixMultiplication.multiply(MATRIX_IDENTITY_2X2, MATRIX_2X2_A);
76+
assertMatrixEquals(MATRIX_2X2_A, result2);
77+
}
78+
79+
@Test
80+
void testMultiplyByZero() {
81+
double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_ZERO_2X2);
82+
assertMatrixEquals(MATRIX_ZERO_2X2, result);
83+
84+
double[][] result2 = StrassenMatrixMultiplication.multiply(MATRIX_ZERO_2X2, MATRIX_2X2_A);
85+
assertMatrixEquals(MATRIX_ZERO_2X2, result2);
86+
}
87+
@Test
88+
void testMultiply1x1() {
89+
double[][] a = {{5.0}};
90+
double[][] b = {{6.0}};
91+
double[][] expected = {{30.0}};
92+
double[][] result = StrassenMatrixMultiplication.multiply(a, b);
93+
assertMatrixEquals(expected, result);
94+
}
95+
96+
97+
@Test
98+
void testNullInput() {
99+
assertThrows(
100+
IllegalArgumentException.class,
101+
() -> StrassenMatrixMultiplication.multiply(null, MATRIX_2X2_B),
102+
"Multiplying with null matrix A should throw exception"
103+
);
104+
assertThrows(
105+
IllegalArgumentException.class,
106+
() -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, null),
107+
"Multiplying with null matrix B should throw exception"
108+
);
109+
}
110+
111+
@Test
112+
void testNonSquareInput() {
113+
assertThrows(
114+
IllegalArgumentException.class,
115+
() -> StrassenMatrixMultiplication.multiply(MATRIX_NON_SQUARE, MATRIX_2X2_B),
116+
"Multiplying non-square matrix A should throw exception"
117+
);
118+
assertThrows(
119+
IllegalArgumentException.class,
120+
() -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_NON_SQUARE),
121+
"Multiplying non-square matrix B should throw exception"
122+
);
123+
}
124+
125+
@Test
126+
void testDifferentSquareDimensions() {
127+
assertThrows(
128+
IllegalArgumentException.class,
129+
() -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_3X3_A),
130+
"Multiplying matrices of different square dimensions should throw exception"
131+
);
132+
}
133+
134+
@Test
135+
void testEmptyMatrix() {
136+
double[][] empty = {};
137+
double[][] result = StrassenMatrixMultiplication.multiply(empty, empty);
138+
assertEquals(0, result.length, "Multiplying empty matrices should result in an empty matrix");
139+
140+
double[][] emptyRows = {{}};
141+
assertThrows(
142+
IllegalArgumentException.class, // Or handle as empty depending on strictness
143+
() -> StrassenMatrixMultiplication.multiply(emptyRows, emptyRows),
144+
"Multiplying matrices with zero columns might throw or return empty"
145+
);
146+
}
147+
}

0 commit comments

Comments
 (0)