Skip to content

Commit 1bacd59

Browse files
committed
feat: add new multiplication methods and improve getNonZeros function
1 parent 104e3e1 commit 1bacd59

File tree

1 file changed

+135
-11
lines changed

1 file changed

+135
-11
lines changed

src/index.js

Lines changed: 135 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,29 @@ export class SparseMatrix {
152152
} else {
153153
this.elements.set(row * this.columns + column, value);
154154
}
155+
155156
return this;
156157
}
157158

158159
/**
159160
* @param {SparseMatrix} other
160161
* @returns {SparseMatrix}
161162
*/
162-
mmul(other) {
163+
mulNew(other) {
164+
if (typeof other !== 'number') {
165+
throw new RangeError('the argument should be a number');
166+
}
167+
168+
// if (this.cardinality / this.columns / this.rows > 0.1)
169+
this.withEachNonZero((i, j, v) => {
170+
// return v * other;
171+
this.set(i, j, v * other);
172+
});
173+
174+
return this;
175+
}
176+
177+
mmulNew(other) {
163178
if (this.columns !== other.rows) {
164179
// eslint-disable-next-line no-console
165180
console.warn(
@@ -181,7 +196,7 @@ export class SparseMatrix {
181196
values: thisValues,
182197
} = this.getNonZeros();
183198

184-
const result = new SparseMatrix(m, p);
199+
const result = new SparseMatrix(m, p, { initialCapacity: m * p });
185200

186201
const nbOtherActive = otherCols.length;
187202
const nbThisActive = thisCols.length;
@@ -199,6 +214,53 @@ export class SparseMatrix {
199214
return result;
200215
}
201216

217+
mmul(other) {
218+
if (this.columns !== other.rows) {
219+
// eslint-disable-next-line no-console
220+
console.warn(
221+
'Number of columns of left matrix are not equal to number of rows of right matrix.',
222+
);
223+
}
224+
225+
const m = this.rows;
226+
const p = other.columns;
227+
228+
const {
229+
columns: otherCols,
230+
rows: otherRows,
231+
values: otherValues,
232+
} = other.getNonZeros('csr');
233+
const {
234+
columns: thisCols,
235+
rows: thisRows,
236+
values: thisValues,
237+
} = this.getNonZeros('csc');
238+
239+
const result = new SparseMatrix(m, p, { initialCapacity: m * p + 20 });
240+
241+
for (let t = 0; t < thisCols.length - 1; t++) {
242+
const j = t;
243+
const tValues = thisValues.subarray(thisCols[t], thisCols[t + 1]);
244+
const tRows = thisRows.subarray(thisCols[t], thisCols[t + 1]);
245+
let initOther = 0;
246+
for (let o = initOther; o < otherRows.length - 1; o++) {
247+
if (o === j) {
248+
initOther++;
249+
const oValues = otherValues.subarray(otherRows[o], otherRows[o + 1]);
250+
const oCols = otherCols.subarray(otherRows[o], otherRows[o + 1]);
251+
for (let f = 0; f < tValues.length; f++) {
252+
for (let k = 0; k < oValues.length; k++) {
253+
const i = tRows[f];
254+
const l = oCols[k];
255+
result.set(i, l, result.get(i, l) + tValues[f] * oValues[k]);
256+
}
257+
}
258+
} else if (j < o) break;
259+
}
260+
}
261+
return result;
262+
}
263+
202264
/**
203265
* @param {SparseMatrix} other
204266
* @returns {SparseMatrix}
@@ -210,7 +272,7 @@ export class SparseMatrix {
210272
const q = other.columns;
211273

212274
const result = new SparseMatrix(m * p, n * q, {
213-
initialCapacity: this.cardinality * other.cardinality,
275+
initialCapacity: this.cardinality * other.cardinality + 20,
214276
});
215277

216278
const {
@@ -285,22 +347,30 @@ export class SparseMatrix {
285347
return this;
286348
}
287349

288-
getNonZeros() {
350+
//'csr' | 'csc' | undefined
351+
getNonZeros(format) {
289352
const cardinality = this.cardinality;
290-
/** @type {number[]} */
291-
const rows = new Array(cardinality);
292-
/** @type {number[]} */
293-
const columns = new Array(cardinality);
294-
/** @type {number[]} */
295-
const values = new Array(cardinality);
353+
/** @type {Float64Array} */
354+
const rows = new Float64Array(cardinality);
355+
/** @type {Float64Array} */
356+
const columns = new Float64Array(cardinality);
357+
/** @type {Float64Array} */
358+
const values = new Float64Array(cardinality);
359+
296360
let idx = 0;
297361
this.withEachNonZero((i, j, value) => {
298362
rows[idx] = i;
299363
columns[idx] = j;
300364
values[idx] = value;
301365
idx++;
302366
});
303-
return { rows, columns, values };
367+
368+
const cooMatrix = { rows, columns, values };
369+
370+
if (!format) return cooMatrix;
371+
372+
const csrMatrix = cooToCsr(cooMatrix, this.rows);
373+
return format === 'csc' ? csrToCsc(csrMatrix, this.columns) : csrMatrix;
304374
}
305375

306376
/**
@@ -1433,3 +1503,57 @@ SparseMatrix.prototype.klass = 'Matrix';
14331503

14341504
SparseMatrix.identity = SparseMatrix.eye;
14351505
SparseMatrix.prototype.tensorProduct = SparseMatrix.prototype.kroneckerProduct;
1506+
1507+
function csrToCsc(csrMatrix, numCols) {
1508+
const {
1509+
values: csrValues,
1510+
columns: csrColIndices,
1511+
rows: csrRowPtr,
1512+
} = csrMatrix;
1513+
// Initialize CSC arrays
1514+
const cscValues = new Float64Array(csrValues.length);
1515+
const cscRowIndices = new Float64Array(csrValues.length);
1516+
const cscColPtr = new Float64Array(numCols + 1);
1517+
1518+
// Count non-zeros per column
1519+
for (let i = 0; i < csrColIndices.length; i++) {
1520+
cscColPtr[csrColIndices[i] + 1]++;
1521+
}
1522+
1523+
// Compute column pointers (prefix sum)
1524+
for (let i = 1; i <= numCols; i++) {
1525+
cscColPtr[i] += cscColPtr[i - 1];
1526+
}
1527+
1528+
// Temporary copy for filling values
1529+
const next = cscColPtr.slice();
1530+
1531+
// Fill CSC arrays
1532+
for (let row = 0; row < csrRowPtr.length - 1; row++) {
1533+
for (let j = csrRowPtr[row], i = 0; j < csrRowPtr[row + 1]; j++, i++) {
1534+
const col = csrColIndices[j];
1535+
const pos = next[col];
1536+
cscValues[pos] = csrValues[j];
1537+
cscRowIndices[pos] = row;
1538+
next[col]++;
1539+
}
1540+
// if (row === 1) break;
1541+
}
1542+
1543+
return { rows: cscRowIndices, columns: cscColPtr, values: cscValues };
1544+
}
1545+
1546+
function cooToCsr(cooMatrix, nbRows = 9) {
1547+
const { values, columns, rows } = cooMatrix;
1548+
//could not be the same length
1549+
const csrRowPtr = new Float64Array(nbRows + 1);
1550+
const length = values.length;
1551+
let currentRow = rows[0];
1552+
for (let index = 0; index < length; ) {
1553+
const prev = index;
1554+
while (currentRow === rows[index] && index < length) ++index;
1555+
csrRowPtr[currentRow + 1] = index;
1556+
currentRow += 1;
1557+
}
1558+
return { rows: csrRowPtr, columns, values };
1559+
}

0 commit comments

Comments
 (0)