Skip to content

Commit 8cef7c0

Browse files
committed
chore: refactor getNonZeros method to simplify format handling and improve performance
1 parent 63b64ed commit 8cef7c0

File tree

2 files changed

+91
-97
lines changed

2 files changed

+91
-97
lines changed

src/__tests__/getNonZeros.test.js

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { describe, it, expect } from 'vitest';
2+
23
import { SparseMatrix } from '../index.js';
34

45
describe('Sparse Matrix', () => {
@@ -22,17 +23,10 @@ describe('Sparse Matrix', () => {
2223
});
2324

2425
// CSR format
25-
expect(m2.getNonZeros({ format: 'csr' })).toEqual({
26+
expect(m2.getNonZeros({ format: true })).toEqual({
2627
rows: Float64Array.from([0, 0, 4, 7, 7, 11]),
2728
columns: Float64Array.from([0, 3, 4, 5, 1, 4, 5, 0, 3, 4, 5]),
2829
values: Float64Array.from([1, 2, 1, 1, 3, 5, 5, 1, 1, 9, 9]),
2930
});
30-
31-
//CSC format
32-
expect(m2.getNonZeros({ format: 'csc' })).toEqual({
33-
rows: Float64Array.from([1, 4, 2, 1, 4, 1, 2, 4, 1, 2, 4]),
34-
columns: Float64Array.from([0, 2, 3, 3, 5, 8, 11]),
35-
values: Float64Array.from([1, 1, 3, 2, 1, 1, 5, 9, 1, 5, 9]),
36-
});
3731
});
3832
});

src/index.js

Lines changed: 89 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -163,40 +163,97 @@ export class SparseMatrix {
163163
);
164164
}
165165

166+
if (this.cardinality < 42 && other.cardinality < 42) {
167+
return this._mmulSmall(other);
168+
} else if (other.rows > 100 && other.cardinality < 110) {
169+
return this._mmulLowDensity(other);
170+
}
171+
172+
return this._mmulMediumDensity(other);
173+
}
174+
175+
_mmulSmall(other) {
166176
const m = this.rows;
167177
const p = other.columns;
178+
const {
179+
columns: otherCols,
180+
rows: otherRows,
181+
values: otherValues,
182+
} = other.getNonZeros();
168183

169-
const result = matrixCreateEmpty(m, p);
184+
const nbOtherActive = otherCols.length;
185+
const result = new SparseMatrix(m, p);
186+
this.withEachNonZero((i, j, v1) => {
187+
for (let o = 0; o < nbOtherActive; o++) {
188+
if (j === otherRows[o]) {
189+
const l = otherCols[o];
190+
result.set(i, l, result.get(i, l) + otherValues[o] * v1);
191+
}
192+
}
193+
});
194+
return result;
195+
}
170196

197+
_mmulLowDensity(other) {
198+
const m = this.rows;
199+
const p = other.columns;
171200
const {
172201
columns: otherCols,
173202
rows: otherRows,
174203
values: otherValues,
175-
} = other.getNonZeros({ format: 'csr' });
204+
} = other.getNonZeros();
176205
const {
177206
columns: thisCols,
178207
rows: thisRows,
179208
values: thisValues,
180-
} = this.getNonZeros({ format: 'csc' });
181-
182-
const thisNbCols = this.columns;
183-
for (let t = 0; t < thisNbCols; t++) {
184-
const tStart = thisCols[t];
185-
const tEnd = thisCols[t + 1];
186-
const oStart = otherRows[t];
187-
const oEnd = otherRows[t + 1];
188-
for (let f = tStart; f < tEnd; f++) {
189-
for (let k = oStart; k < oEnd; k++) {
190-
const i = thisRows[f];
191-
const l = otherCols[k];
192-
result[i][l] += thisValues[f] * otherValues[k];
209+
} = this.getNonZeros();
210+
211+
const result = new SparseMatrix(m, p);
212+
const nbOtherActive = otherCols.length;
213+
const nbThisActive = thisCols.length;
214+
for (let t = 0; t < nbThisActive; t++) {
215+
const i = thisRows[t];
216+
const j = thisCols[t];
217+
for (let o = 0; o < nbOtherActive; o++) {
218+
if (j === otherRows[o]) {
219+
const l = otherCols[o];
220+
result.set(i, l, result.get(i, l) + otherValues[o] * thisValues[t]);
193221
}
194222
}
195223
}
196-
197-
return new SparseMatrix(result);
224+
// console.log(result.cardinality);
225+
return result;
198226
}
199227

228+
_mmulMediumDensity(other) {
229+
const m = this.rows;
230+
const p = other.columns;
231+
const {
232+
columns: otherCols,
233+
rows: otherRows,
234+
values: otherValues,
235+
} = other.getNonZeros({ format: 'csr' });
236+
const {
237+
columns: thisCols,
238+
rows: thisRows,
239+
values: thisValues,
240+
} = this.getNonZeros();
241+
242+
const result = new SparseMatrix(m, p);
243+
const nbThisActive = thisCols.length;
244+
for (let t = 0; t < nbThisActive; t++) {
245+
const i = thisRows[t];
246+
const j = thisCols[t];
247+
const oStart = otherRows[j];
248+
const oEnd = otherRows[j + 1];
249+
for (let k = oStart; k < oEnd; k++) {
250+
const l = otherCols[k];
251+
result.set(i, l, result.get(i, l) + otherValues[k] * thisValues[t]);
252+
}
253+
}
254+
255+
return result;
256+
}
200257
/**
201258
* @param {SparseMatrix} other
202259
* @returns {SparseMatrix}
@@ -292,11 +349,10 @@ export class SparseMatrix {
292349
* Returns the non-zero elements of the matrix in coordinate (COO), CSR, or CSC format.
293350
*
294351
* @param {Object} [options={}] - Options for output format and sorting.
295-
* @param {'csr'|'csc'} [options.format] - If specified, returns the result in CSR or CSC format. Otherwise, returns COO format.
352+
* @param {boolean} [options.format] - If specified, returns the result in CSR or CSC format. Otherwise, returns COO format.
296353
* @param {boolean} [options.sort] - If true, sorts the non-zero elements by their indices.
297354
* @returns {Object} If no format is specified, returns an object with Float64Array `rows`, `columns`, and `values` (COO format).
298-
* If format is 'csr', returns { rows, columns, values } in CSR format.
299-
* If format is 'csc', returns { rows, columns, values } in CSC format.
355+
* If format is true, returns { rows, columns, values } in CSR format.
300356
* @throws {Error} If an unsupported format is specified.
301357
*/
302358
getNonZeros(options = {}) {
@@ -318,16 +374,9 @@ export class SparseMatrix {
318374
idx++;
319375
}, sort);
320376

321-
if (!format) return { rows, columns, values };
322-
323-
if (!['csr', 'csc'].includes(format)) {
324-
throw new Error(`format ${format} is not supported`);
325-
}
326-
327-
const csrMatrix = cooToCsr({ rows, columns, values }, this.rows);
328-
return format.toLowerCase() === 'csc'
329-
? csrToCsc(csrMatrix, this.columns)
330-
: csrMatrix;
377+
return format
378+
? cooToCsr({ rows, columns, values }, this.rows)
379+
: { rows, columns, values };
331380
}
332381

333382
/**
@@ -1469,46 +1518,6 @@ SparseMatrix.prototype.klass = 'Matrix';
14691518
SparseMatrix.identity = SparseMatrix.eye;
14701519
SparseMatrix.prototype.tensorProduct = SparseMatrix.prototype.kroneckerProduct;
14711520

1472-
/**
1473-
* Converts a matrix from Compressed Sparse Row (CSR) format to Compressed Sparse Column (CSC) format.
1474-
* @param {Object} csrMatrix - The matrix in CSR format with properties: values, columns, rows.
1475-
* @param {number} numCols - The number of columns in the matrix.
1476-
* @returns {{rows: Float64Array, columns: Float64Array, values: Float64Array}} The matrix in CSC format.
1477-
*/
1478-
function csrToCsc(csrMatrix, numCols) {
1479-
const {
1480-
values: csrValues,
1481-
columns: csrColIndices,
1482-
rows: csrRowPtr,
1483-
} = csrMatrix;
1484-
1485-
const cscValues = new Float64Array(csrValues.length);
1486-
const cscRowIndices = new Float64Array(csrValues.length);
1487-
const cscColPtr = new Float64Array(numCols + 1);
1488-
1489-
for (let i = 0; i < csrColIndices.length; i++) {
1490-
cscColPtr[csrColIndices[i] + 1]++;
1491-
}
1492-
1493-
for (let i = 1; i <= numCols; i++) {
1494-
cscColPtr[i] += cscColPtr[i - 1];
1495-
}
1496-
1497-
const next = cscColPtr.slice();
1498-
1499-
for (let row = 0; row < csrRowPtr.length - 1; row++) {
1500-
for (let j = csrRowPtr[row], i = 0; j < csrRowPtr[row + 1]; j++, i++) {
1501-
const col = csrColIndices[j];
1502-
const pos = next[col];
1503-
cscValues[pos] = csrValues[j];
1504-
cscRowIndices[pos] = row;
1505-
next[col]++;
1506-
}
1507-
}
1508-
1509-
return { rows: cscRowIndices, columns: cscColPtr, values: cscValues };
1510-
}
1511-
15121521
/**
15131522
* Converts a matrix from Coordinate (COO) format to Compressed Sparse Row (CSR) format.
15141523
* @param {Object} cooMatrix - The matrix in COO format with properties: values, columns, rows.
@@ -1518,26 +1527,17 @@ function csrToCsc(csrMatrix, numCols) {
15181527
function cooToCsr(cooMatrix, nbRows) {
15191528
const { values, columns, rows } = cooMatrix;
15201529
const csrRowPtr = new Float64Array(nbRows + 1);
1521-
const length = values.length;
1522-
let currentRow = rows[0];
1523-
for (let index = 0; index < length; ) {
1524-
while (currentRow === rows[index] && index < length) ++index;
1525-
csrRowPtr[currentRow + 1] = index;
1526-
currentRow += 1;
1530+
1531+
// Count non-zeros per row
1532+
const numberOfNonZeros = rows.length;
1533+
for (let i = 0; i < numberOfNonZeros; i++) {
1534+
csrRowPtr[rows[i] + 1]++;
15271535
}
1528-
return { rows: csrRowPtr, columns, values };
1529-
}
15301536

1531-
/**
1532-
* Creates an empty 2D matrix (array of Float64Array) with the given dimensions.
1533-
* @param {number} nbRows - Number of rows.
1534-
* @param {number} nbColumns - Number of columns.
1535-
* @returns {Float64Array[]} A 2D array representing the matrix.
1536-
*/
1537-
function matrixCreateEmpty(nbRows, nbColumns) {
1538-
const newMatrix = [];
1539-
for (let row = 0; row < nbRows; row++) {
1540-
newMatrix.push(new Float64Array(nbColumns));
1537+
// Compute cumulative sum
1538+
for (let i = 1; i <= nbRows; i++) {
1539+
csrRowPtr[i] += csrRowPtr[i - 1];
15411540
}
1542-
return newMatrix;
1541+
1542+
return { rows: csrRowPtr, columns, values };
15431543
}

0 commit comments

Comments
 (0)