Skip to content

Commit 8505401

Browse files
committed
chore(getNonZeros): rename format options to csr and move cooToCsr as an utility function
1 parent 163c079 commit 8505401

File tree

3 files changed

+74
-43
lines changed

3 files changed

+74
-43
lines changed

src/__tests__/getNonZeros.test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ describe('Sparse Matrix', () => {
2323
});
2424

2525
// CSR format
26-
expect(m2.getNonZeros({ format: true })).toEqual({
26+
expect(m2.getNonZeros({ csr: true })).toEqual({
2727
rows: Float64Array.from([0, 0, 4, 7, 7, 11]),
2828
columns: Float64Array.from([0, 3, 4, 5, 1, 4, 5, 0, 3, 4, 5]),
2929
values: Float64Array.from([1, 2, 1, 1, 3, 5, 5, 1, 1, 9, 9]),

src/index.js

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import HashTable from 'ml-hash-table';
2+
import { cooToCsr } from './utils/cooToCsr.js';
23

34
/** @typedef {(row: number, column: number, value: number) => void} WithEachNonZeroCallback */
45
/** @typedef {(row: number, column: number, value: number) => number | false} ForEachNonZeroCallback */
@@ -157,13 +158,14 @@ export class SparseMatrix {
157158
}
158159

159160
/**
161+
* Matrix multiplication, does not modify the current instance.
160162
* @param {SparseMatrix} other
161-
* @returns {SparseMatrix}
163+
* @returns {SparseMatrix} returns a new matrix instance.
164+
* @throws {Error} If the number of columns of this matrix does not match the number of rows of the other matrix.
162165
*/
163166
mmul(other) {
164167
if (this.columns !== other.rows) {
165-
// eslint-disable-next-line no-console
166-
console.warn(
168+
throw new RangeError(
167169
'Number of columns of left matrix are not equal to number of rows of right matrix.',
168170
);
169171
}
@@ -177,6 +179,13 @@ export class SparseMatrix {
177179
return this._mmulMediumDensity(other);
178180
}
179181

182+
/**
183+
* Matrix multiplication optimized for very small matrices (both cardinalities < 42).
184+
*
185+
* @private
186+
* @param {SparseMatrix} other - The right-hand side matrix to multiply with.
187+
* @returns {SparseMatrix} - The resulting matrix after multiplication.
188+
*/
180189
_mmulSmall(other) {
181190
const m = this.rows;
182191
const p = other.columns;
@@ -199,6 +208,13 @@ export class SparseMatrix {
199208
return result;
200209
}
201210

211+
/**
212+
* Matrix multiplication optimized for low-density right-hand side matrices (other.rows > 100 and other.cardinality < 100).
213+
*
214+
* @private
215+
* @param {SparseMatrix} other - The right-hand side matrix to multiply with.
216+
* @returns {SparseMatrix} - The resulting matrix after multiplication.
217+
*/
202218
_mmulLowDensity(other) {
203219
const m = this.rows;
204220
const p = other.columns;
@@ -230,20 +246,27 @@ export class SparseMatrix {
230246
return result;
231247
}
232248

249+
/**
250+
* Matrix multiplication for medium-density matrices using CSR format for the right-hand side.
251+
*
252+
* @private
253+
* @param {SparseMatrix} other - The right-hand side matrix to multiply with.
254+
* @returns {SparseMatrix} - The resulting matrix after multiplication.
255+
*/
233256
_mmulMediumDensity(other) {
234-
const m = this.rows;
235-
const p = other.columns;
236-
const {
237-
columns: otherCols,
238-
rows: otherRows,
239-
values: otherValues,
240-
} = other.getNonZeros({ format: 'csr' });
241257
const {
242258
columns: thisCols,
243259
rows: thisRows,
244260
values: thisValues,
245261
} = this.getNonZeros();
262+
const {
263+
columns: otherCols,
264+
rows: otherRows,
265+
values: otherValues,
266+
} = other.getNonZeros({ csr: true });
246267

268+
const m = this.rows;
269+
const p = other.columns;
247270
const result = new SparseMatrix(m, p);
248271
const nbThisActive = thisCols.length;
249272
for (let t = 0; t < nbThisActive; t++) {
@@ -358,32 +381,41 @@ export class SparseMatrix {
358381
}
359382

360383
/**
361-
* Returns the non-zero elements of the matrix in coordinate (COO), CSR, or CSC format.
384+
* Returns the non-zero elements of the matrix in coordinates (COO) or CSR format.
385+
*
386+
* **COO (Coordinate) format:**
387+
* Stores the non-zero elements as three arrays: `rows`, `columns`, and `values`, where each index corresponds to a non-zero entry at (row, column) with the given value.
388+
*
389+
* **CSR (Compressed Sparse Row) format:**
390+
* Stores the matrix using three arrays:
391+
* - `rows`: Row pointer array of length `numRows + 1`, where each entry indicates the start of a row in the `columns` and `values` arrays.
392+
* - `columns`: Column indices of non-zero elements.
393+
* - `values`: Non-zero values.
394+
* This format is efficient for row slicing and matrix-vector multiplication.
362395
*
363396
* @param {Object} [options={}] - Options for output format and sorting.
364-
* @param {boolean} [options.format] - If specified, returns the result in CSR or CSC format. Otherwise, returns COO format.
397+
* @param {boolean} [options.csr] - If true, returns the result in CSR format. Otherwise, returns COO format.
365398
* @param {boolean} [options.sort] - If true, sorts the non-zero elements by their indices.
366-
* @returns {Object} If no format is specified, returns an object with Float64Array `rows`, `columns`, and `values` (COO format).
367-
* If format is true, returns { rows, columns, values } in CSR format.
368-
* @throws {Error} If an unsupported format is specified.
399+
* @returns {Object} If `csr` is not specified, returns an object with Float64Array `rows`, `columns`, and `values` (COO format).
400+
* If `csr` is true, returns `{ rows, columns, values }` in CSR format.
369401
*/
370402
getNonZeros(options = {}) {
371403
const cardinality = this.cardinality;
372404
const rows = new Float64Array(cardinality);
373405
const columns = new Float64Array(cardinality);
374406
const values = new Float64Array(cardinality);
375407

376-
const { format, sort = format !== undefined } = options;
408+
const { csr, sort } = options;
377409

378410
let idx = 0;
379411
this.withEachNonZero((i, j, value) => {
380412
rows[idx] = i;
381413
columns[idx] = j;
382414
values[idx] = value;
383415
idx++;
384-
}, sort);
416+
}, sort || csr);
385417

386-
return format
418+
return csr
387419
? cooToCsr({ rows, columns, values }, this.rows)
388420
: { rows, columns, values };
389421
}
@@ -1518,27 +1550,3 @@ SparseMatrix.prototype.klass = 'Matrix';
15181550

15191551
SparseMatrix.identity = SparseMatrix.eye;
15201552
SparseMatrix.prototype.tensorProduct = SparseMatrix.prototype.kroneckerProduct;
1521-
1522-
/**
1523-
* Converts a matrix from Coordinate (COO) format to Compressed Sparse Row (CSR) format.
1524-
* @param {Object} cooMatrix - The matrix in COO format with properties: values, columns, rows.
1525-
* @param {number} nbRows - The number of rows in the matrix.
1526-
* @returns {{rows: Float64Array, columns: Float64Array, values: Float64Array}} The matrix in CSR format.
1527-
*/
1528-
function cooToCsr(cooMatrix, nbRows) {
1529-
const { values, columns, rows } = cooMatrix;
1530-
const csrRowPtr = new Float64Array(nbRows + 1);
1531-
1532-
// Count non-zeros per row
1533-
const numberOfNonZeros = rows.length;
1534-
for (let i = 0; i < numberOfNonZeros; i++) {
1535-
csrRowPtr[rows[i] + 1]++;
1536-
}
1537-
1538-
// Compute cumulative sum
1539-
for (let i = 1; i <= nbRows; i++) {
1540-
csrRowPtr[i] += csrRowPtr[i - 1];
1541-
}
1542-
1543-
return { rows: csrRowPtr, columns, values };
1544-
}

src/utils/cooToCsr.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/**
2+
* Converts a matrix from Coordinate (COO) format to Compressed Sparse Row (CSR) format.
3+
* @param {Object} cooMatrix - The matrix in COO format with properties: values, columns, rows.
4+
* @param {number} nbRows - The number of rows in the matrix.
5+
* @returns {{rows: Float64Array, columns: Float64Array, values: Float64Array}} The matrix in CSR format.
6+
*/
7+
export function cooToCsr(cooMatrix, nbRows) {
8+
const { values, columns, rows } = cooMatrix;
9+
const csrRowPtr = new Float64Array(nbRows + 1);
10+
11+
// Count non-zeros per row
12+
const numberOfNonZeros = rows.length;
13+
for (let i = 0; i < numberOfNonZeros; i++) {
14+
csrRowPtr[rows[i] + 1]++;
15+
}
16+
17+
// Compute cumulative sum
18+
for (let i = 1; i <= nbRows; i++) {
19+
csrRowPtr[i] += csrRowPtr[i - 1];
20+
}
21+
22+
return { rows: csrRowPtr, columns, values };
23+
}

0 commit comments

Comments
 (0)