Skip to content

Commit 10aee91

Browse files
committed
test: refactor matrix multiplication tests and improve random matrix generation
1 parent 8505401 commit 10aee91

File tree

3 files changed

+59
-40
lines changed

3 files changed

+59
-40
lines changed

src/__tests__/index.test.js

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { Matrix } from 'ml-matrix';
12
import { describe, expect, it } from 'vitest';
23

34
import { SparseMatrix } from '../index.js';
@@ -38,33 +39,46 @@ describe('Sparse Matrix', () => {
3839
expect(m3.cardinality).toBe(1);
3940

4041
expect(m3.get(0, 1)).toBe(2);
41-
expect(m3.to2DArray()).toStrictEqual([
42+
expectMatrixClose(m3.to2DArray(), [
4243
[0, 2],
4344
[0, 0],
4445
]);
4546

4647
// Compare with dense multiplication
47-
const denseM1 = m1.to2DArray();
48-
const denseM2 = m2.to2DArray();
49-
const expectedDense = denseMatrixMultiply(denseM1, denseM2);
50-
expect(m3.to2DArray()).toStrictEqual(expectedDense);
48+
const denseM1 = new Matrix(m1.to2DArray());
49+
const denseM2 = new Matrix(m2.to2DArray());
50+
const expectedDense = denseM1.mmul(denseM2);
51+
expectMatrixClose(m3.to2DArray(), expectedDense.to2DArray());
5152
});
5253

5354
it('mmul', () => {
5455
const size = 32;
5556
const density = 0.1;
56-
const m1 = randomSparseMatrix(size, size, density);
57-
const m2 = randomSparseMatrix(size, size, density);
58-
let m3 = m1.mmul(m2);
59-
60-
const denseM1 = m1.to2DArray();
61-
const denseM2 = m2.to2DArray();
62-
63-
const newSparse = new SparseMatrix(denseM1);
64-
expect(newSparse.to2DArray()).toStrictEqual(denseM1);
65-
const expectedDense = denseMatrixMultiply(denseM1, denseM2);
57+
const A = randomMatrix(size, size, density * size ** 2);
58+
const B = randomMatrix(size, size, density * size ** 2);
59+
const m1 = new SparseMatrix(A);
60+
const m2 = new SparseMatrix(B);
61+
const m3 = m1.mmul(m2);
62+
63+
const denseM1 = new Matrix(A);
64+
const denseM2 = new Matrix(B);
65+
const expectedDense = denseM1.mmul(denseM2);
66+
expectMatrixClose(m3.to2DArray(), expectedDense.to2DArray());
67+
});
6668

67-
expect(m3.to2DArray()).toStrictEqual(expectedDense);
69+
it('mmul with low density', () => {
70+
const size = 128;
71+
const cardinality = 64;
72+
const A = randomMatrix(size, size, cardinality);
73+
const B = randomMatrix(size, size, cardinality);
74+
const m1 = new SparseMatrix(A);
75+
const m2 = new SparseMatrix(B);
76+
const m3 = m1.mmul(m2);
77+
78+
const denseM1 = new Matrix(A);
79+
const denseM2 = new Matrix(B);
80+
const expectedDense = denseM1.mmul(denseM2);
81+
expectMatrixClose(m3.to2DArray(), expectedDense.to2DArray());
6882
});
6983

7084
it('kronecker', () => {
@@ -168,31 +182,37 @@ describe('Banded matrices', () => {
168182
});
169183
});
170184

171-
function denseMatrixMultiply(A, B) {
172-
const rowsA = A.length;
173-
const colsA = A[0].length;
174-
const colsB = B[0].length;
175-
const result = Array.from({ length: rowsA }, () => Array(colsB).fill(0));
176-
for (let i = 0; i < rowsA; i++) {
177-
for (let j = 0; j < colsB; j++) {
178-
for (let k = 0; k < colsA; k++) {
179-
result[i][j] += A[i][k] * B[k][j];
180-
}
185+
/**
186+
* Helper to compare two 2D arrays element-wise using toBeCloseTo.
187+
*/
188+
function expectMatrixClose(received, expected, precision = 6) {
189+
expect(received.length).toBe(expected.length);
190+
for (let i = 0; i < received.length; i++) {
191+
expect(received[i].length).toBe(expected[i].length);
192+
for (let j = 0; j < received[i].length; j++) {
193+
expect(received[i][j]).toBeCloseTo(expected[i][j], precision);
181194
}
182195
}
183-
return result;
184196
}
185197

186-
function randomSparseMatrix(rows, cols, density = 0.01) {
187-
const matrix = [];
188-
for (let i = 0; i < rows; i++) {
189-
const row = new Float64Array(cols);
190-
for (let j = 0; j < cols; j++) {
191-
if (Math.random() < density) {
192-
row[j] = Math.random() * 10;
193-
}
194-
}
195-
matrix.push(row);
198+
function randomMatrix(rows, cols, cardinality) {
199+
const total = rows * cols;
200+
const positions = new Set();
201+
202+
// Generate unique random positions
203+
while (positions.size < cardinality) {
204+
positions.add(Math.floor(Math.random() * total));
196205
}
197-
return new SparseMatrix(matrix);
206+
207+
// Build the matrix with zeros
208+
const matrix = Array.from({ length: rows }, () => new Float64Array(cols));
209+
210+
// Assign random values to the selected positions
211+
for (const pos of positions) {
212+
const i = Math.floor(pos / cols);
213+
const j = pos % cols;
214+
matrix[i][j] = Math.random() * 10;
215+
}
216+
217+
return matrix;
198218
}

src/index.js

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import HashTable from 'ml-hash-table';
2+
23
import { cooToCsr } from './utils/cooToCsr.js';
34

45
/** @typedef {(row: number, column: number, value: number) => void} WithEachNonZeroCallback */
@@ -169,13 +170,11 @@ export class SparseMatrix {
169170
'Number of columns of left matrix are not equal to number of rows of right matrix.',
170171
);
171172
}
172-
173173
if (this.cardinality < 42 && other.cardinality < 42) {
174174
return this._mmulSmall(other);
175175
} else if (other.rows > 100 && other.cardinality < 100) {
176176
return this._mmulLowDensity(other);
177177
}
178-
179178
return this._mmulMediumDensity(other);
180179
}
181180

File renamed without changes.

0 commit comments

Comments
 (0)