|
| 1 | +import { Matrix } from 'ml-matrix'; |
1 | 2 | import { describe, expect, it } from 'vitest';
|
2 | 3 |
|
3 | 4 | import { SparseMatrix } from '../index.js';
|
@@ -38,33 +39,46 @@ describe('Sparse Matrix', () => {
|
38 | 39 | expect(m3.cardinality).toBe(1);
|
39 | 40 |
|
40 | 41 | expect(m3.get(0, 1)).toBe(2);
|
41 |
| - expect(m3.to2DArray()).toStrictEqual([ |
| 42 | + expectMatrixClose(m3.to2DArray(), [ |
42 | 43 | [0, 2],
|
43 | 44 | [0, 0],
|
44 | 45 | ]);
|
45 | 46 |
|
46 | 47 | // Compare with dense multiplication
|
47 |
| - const denseM1 = m1.to2DArray(); |
48 |
| - const denseM2 = m2.to2DArray(); |
49 |
| - const expectedDense = denseMatrixMultiply(denseM1, denseM2); |
50 |
| - expect(m3.to2DArray()).toStrictEqual(expectedDense); |
| 48 | + const denseM1 = new Matrix(m1.to2DArray()); |
| 49 | + const denseM2 = new Matrix(m2.to2DArray()); |
| 50 | + const expectedDense = denseM1.mmul(denseM2); |
| 51 | + expectMatrixClose(m3.to2DArray(), expectedDense.to2DArray()); |
51 | 52 | });
|
52 | 53 |
|
53 | 54 | it('mmul', () => {
|
54 | 55 | const size = 32;
|
55 | 56 | const density = 0.1;
|
56 |
| - const m1 = randomSparseMatrix(size, size, density); |
57 |
| - const m2 = randomSparseMatrix(size, size, density); |
58 |
| - let m3 = m1.mmul(m2); |
59 |
| - |
60 |
| - const denseM1 = m1.to2DArray(); |
61 |
| - const denseM2 = m2.to2DArray(); |
62 |
| - |
63 |
| - const newSparse = new SparseMatrix(denseM1); |
64 |
| - expect(newSparse.to2DArray()).toStrictEqual(denseM1); |
65 |
| - const expectedDense = denseMatrixMultiply(denseM1, denseM2); |
| 57 | + const A = randomMatrix(size, size, density * size ** 2); |
| 58 | + const B = randomMatrix(size, size, density * size ** 2); |
| 59 | + const m1 = new SparseMatrix(A); |
| 60 | + const m2 = new SparseMatrix(B); |
| 61 | + const m3 = m1.mmul(m2); |
| 62 | + |
| 63 | + const denseM1 = new Matrix(A); |
| 64 | + const denseM2 = new Matrix(B); |
| 65 | + const expectedDense = denseM1.mmul(denseM2); |
| 66 | + expectMatrixClose(m3.to2DArray(), expectedDense.to2DArray()); |
| 67 | + }); |
66 | 68 |
|
67 |
| - expect(m3.to2DArray()).toStrictEqual(expectedDense); |
| 69 | + it('mmul with low density', () => { |
| 70 | + const size = 128; |
| 71 | + const cardinality = 64; |
| 72 | + const A = randomMatrix(size, size, cardinality); |
| 73 | + const B = randomMatrix(size, size, cardinality); |
| 74 | + const m1 = new SparseMatrix(A); |
| 75 | + const m2 = new SparseMatrix(B); |
| 76 | + const m3 = m1.mmul(m2); |
| 77 | + |
| 78 | + const denseM1 = new Matrix(A); |
| 79 | + const denseM2 = new Matrix(B); |
| 80 | + const expectedDense = denseM1.mmul(denseM2); |
| 81 | + expectMatrixClose(m3.to2DArray(), expectedDense.to2DArray()); |
68 | 82 | });
|
69 | 83 |
|
70 | 84 | it('kronecker', () => {
|
@@ -168,31 +182,37 @@ describe('Banded matrices', () => {
|
168 | 182 | });
|
169 | 183 | });
|
170 | 184 |
|
171 |
| -function denseMatrixMultiply(A, B) { |
172 |
| - const rowsA = A.length; |
173 |
| - const colsA = A[0].length; |
174 |
| - const colsB = B[0].length; |
175 |
| - const result = Array.from({ length: rowsA }, () => Array(colsB).fill(0)); |
176 |
| - for (let i = 0; i < rowsA; i++) { |
177 |
| - for (let j = 0; j < colsB; j++) { |
178 |
| - for (let k = 0; k < colsA; k++) { |
179 |
| - result[i][j] += A[i][k] * B[k][j]; |
180 |
| - } |
| 185 | +/** |
| 186 | + * Helper to compare two 2D arrays element-wise using toBeCloseTo. |
| 187 | + */ |
| 188 | +function expectMatrixClose(received, expected, precision = 6) { |
| 189 | + expect(received.length).toBe(expected.length); |
| 190 | + for (let i = 0; i < received.length; i++) { |
| 191 | + expect(received[i].length).toBe(expected[i].length); |
| 192 | + for (let j = 0; j < received[i].length; j++) { |
| 193 | + expect(received[i][j]).toBeCloseTo(expected[i][j], precision); |
181 | 194 | }
|
182 | 195 | }
|
183 |
| - return result; |
184 | 196 | }
|
185 | 197 |
|
186 |
| -function randomSparseMatrix(rows, cols, density = 0.01) { |
187 |
| - const matrix = []; |
188 |
| - for (let i = 0; i < rows; i++) { |
189 |
| - const row = new Float64Array(cols); |
190 |
| - for (let j = 0; j < cols; j++) { |
191 |
| - if (Math.random() < density) { |
192 |
| - row[j] = Math.random() * 10; |
193 |
| - } |
194 |
| - } |
195 |
| - matrix.push(row); |
| 198 | +function randomMatrix(rows, cols, cardinality) { |
| 199 | + const total = rows * cols; |
| 200 | + const positions = new Set(); |
| 201 | + |
| 202 | + // Generate unique random positions |
| 203 | + while (positions.size < cardinality) { |
| 204 | + positions.add(Math.floor(Math.random() * total)); |
196 | 205 | }
|
197 |
| - return new SparseMatrix(matrix); |
| 206 | + |
| 207 | + // Build the matrix with zeros |
| 208 | + const matrix = Array.from({ length: rows }, () => new Float64Array(cols)); |
| 209 | + |
| 210 | + // Assign random values to the selected positions |
| 211 | + for (const pos of positions) { |
| 212 | + const i = Math.floor(pos / cols); |
| 213 | + const j = pos % cols; |
| 214 | + matrix[i][j] = Math.random() * 10; |
| 215 | + } |
| 216 | + |
| 217 | + return matrix; |
198 | 218 | }
|
0 commit comments