Skip to content

Commit 9a44f16

Browse files
committed
feat: add new multiplication methods and improve getNonZeros function
1 parent e4a7002 commit 9a44f16

File tree

1 file changed

+131
-8
lines changed

1 file changed

+131
-8
lines changed

src/index.js

Lines changed: 131 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ export class SparseMatrix {
131131
} else {
132132
this.elements.set(row * this.columns + column, value);
133133
}
134+
134135
return this;
135136
}
136137

@@ -147,7 +148,21 @@ export class SparseMatrix {
147148
return this;
148149
}
149150

150-
mmul(other) {
151+
mulNew(other) {
152+
if (typeof other !== 'number') {
153+
throw new RangeError('the argument should be a number');
154+
}
155+
156+
// if (this.cardinality / this.columns / this.rows > 0.1)
157+
this.withEachNonZero((i, j, v) => {
158+
// return v * other;
159+
this.set(i, j, v * other);
160+
});
161+
162+
return this;
163+
}
164+
165+
mmulNew(other) {
151166
if (this.columns !== other.rows) {
152167
// eslint-disable-next-line no-console
153168
console.warn(
@@ -169,7 +184,7 @@ export class SparseMatrix {
169184
values: thisValues,
170185
} = this.getNonZeros();
171186

172-
const result = new SparseMatrix(m, p);
187+
const result = new SparseMatrix(m, p, { initialCapacity: m * p });
173188

174189
const nbOtherActive = otherCols.length;
175190
const nbThisActive = thisCols.length;
@@ -187,14 +202,61 @@ export class SparseMatrix {
187202
return result;
188203
}
189204

205+
mmul(other) {
206+
if (this.columns !== other.rows) {
207+
// eslint-disable-next-line no-console
208+
console.warn(
209+
'Number of columns of left matrix are not equal to number of rows of right matrix.',
210+
);
211+
}
212+
213+
const m = this.rows;
214+
const p = other.columns;
215+
216+
const {
217+
columns: otherCols,
218+
rows: otherRows,
219+
values: otherValues,
220+
} = other.getNonZeros('csr');
221+
const {
222+
columns: thisCols,
223+
rows: thisRows,
224+
values: thisValues,
225+
} = this.getNonZeros('csc');
226+
227+
const result = new SparseMatrix(m, p, { initialCapacity: m * p + 20 });
228+
229+
for (let t = 0; t < thisCols.length - 1; t++) {
230+
const j = t;
231+
const tValues = thisValues.subarray(thisCols[t], thisCols[t + 1]);
232+
const tRows = thisRows.subarray(thisCols[t], thisCols[t + 1]);
233+
let initOther = 0;
234+
for (let o = initOther; o < otherRows.length - 1; o++) {
235+
if (o === j) {
236+
initOther++;
237+
const oValues = otherValues.subarray(otherRows[o], otherRows[o + 1]);
238+
const oCols = otherCols.subarray(otherRows[o], otherRows[o + 1]);
239+
for (let f = 0; f < tValues.length; f++) {
240+
for (let k = 0; k < oValues.length; k++) {
241+
const i = tRows[f];
242+
const l = oCols[k];
243+
result.set(i, l, result.get(i, l) + tValues[f] * oValues[k]);
244+
}
245+
}
246+
} else if (j < o) break;
247+
}
248+
}
249+
return result;
250+
}
251+
190252
kroneckerProduct(other) {
191253
const m = this.rows;
192254
const n = this.columns;
193255
const p = other.rows;
194256
const q = other.columns;
195257

196258
const result = new SparseMatrix(m * p, n * q, {
197-
initialCapacity: this.cardinality * other.cardinality,
259+
initialCapacity: this.cardinality * other.cardinality + 20,
198260
});
199261

200262
const {
@@ -259,19 +321,26 @@ export class SparseMatrix {
259321
return this;
260322
}
261323

262-
getNonZeros() {
324+
//'csr' | 'csc' | undefined
325+
getNonZeros(format) {
263326
const cardinality = this.cardinality;
264-
const rows = new Array(cardinality);
265-
const columns = new Array(cardinality);
266-
const values = new Array(cardinality);
327+
const rows = new Float64Array(cardinality);
328+
const columns = new Float64Array(cardinality);
329+
const values = new Float64Array(cardinality);
267330
let idx = 0;
268331
this.withEachNonZero((i, j, value) => {
269332
rows[idx] = i;
270333
columns[idx] = j;
271334
values[idx] = value;
272335
idx++;
273336
});
274-
return { rows, columns, values };
337+
338+
const cooMatrix = { rows, columns, values };
339+
340+
if (!format) return cooMatrix;
341+
342+
const csrMatrix = cooToCsr(cooMatrix, this.rows);
343+
return format === 'csc' ? csrToCsc(csrMatrix, this.columns) : csrMatrix;
275344
}
276345

277346
setThreshold(newThreshold) {
@@ -452,3 +521,57 @@ function fillTemplateFunction(template, values) {
452521
}
453522
return template;
454523
}
524+
525+
function csrToCsc(csrMatrix, numCols) {
526+
const {
527+
values: csrValues,
528+
columns: csrColIndices,
529+
rows: csrRowPtr,
530+
} = csrMatrix;
531+
// Initialize CSC arrays
532+
const cscValues = new Float64Array(csrValues.length);
533+
const cscRowIndices = new Float64Array(csrValues.length);
534+
const cscColPtr = new Float64Array(numCols + 1);
535+
536+
// Count non-zeros per column
537+
for (let i = 0; i < csrColIndices.length; i++) {
538+
cscColPtr[csrColIndices[i] + 1]++;
539+
}
540+
541+
// Compute column pointers (prefix sum)
542+
for (let i = 1; i <= numCols; i++) {
543+
cscColPtr[i] += cscColPtr[i - 1];
544+
}
545+
546+
// Temporary copy for filling values
547+
const next = cscColPtr.slice();
548+
549+
// Fill CSC arrays
550+
for (let row = 0; row < csrRowPtr.length - 1; row++) {
551+
for (let j = csrRowPtr[row], i = 0; j < csrRowPtr[row + 1]; j++, i++) {
552+
const col = csrColIndices[j];
553+
const pos = next[col];
554+
cscValues[pos] = csrValues[j];
555+
cscRowIndices[pos] = row;
556+
next[col]++;
557+
}
558+
// if (row === 1) break;
559+
}
560+
561+
return { rows: cscRowIndices, columns: cscColPtr, values: cscValues };
562+
}
563+
564+
function cooToCsr(cooMatrix, nbRows = 9) {
565+
const { values, columns, rows } = cooMatrix;
566+
//could not be the same length
567+
const csrRowPtr = new Float64Array(nbRows + 1);
568+
const length = values.length;
569+
let currentRow = rows[0];
570+
for (let index = 0; index < length; ) {
571+
const prev = index;
572+
while (currentRow === rows[index] && index < length) ++index;
573+
csrRowPtr[currentRow + 1] = index;
574+
currentRow += 1;
575+
}
576+
return { rows: csrRowPtr, columns, values };
577+
}

0 commit comments

Comments
 (0)