Skip to content

Commit 104e3e1

Browse files
committed
feat: improve performance of mmul and kronecker product
1 parent 099f24f commit 104e3e1

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

src/__tests__/index.test.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ describe('Sparse Matrix', () => {
3838
expect(m3.cardinality).toBe(1);
3939

4040
expect(m3.get(0, 1)).toBe(2);
41+
expect(m3.to2DArray()).toStrictEqual([
42+
[0, 2],
43+
[0, 0],
44+
]);
4145
});
4246

4347
it('kronecker', () => {

src/index.js

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,32 @@ export class SparseMatrix {
170170
const m = this.rows;
171171
const p = other.columns;
172172

173+
const {
174+
columns: otherCols,
175+
rows: otherRows,
176+
values: otherValues,
177+
} = other.getNonZeros();
178+
const {
179+
columns: thisCols,
180+
rows: thisRows,
181+
values: thisValues,
182+
} = this.getNonZeros();
183+
173184
const result = new SparseMatrix(m, p);
174-
this.withEachNonZero((i, j, v1) => {
175-
other.withEachNonZero((k, l, v2) => {
176-
if (j === k) {
177-
result.set(i, l, result.get(i, l) + v1 * v2);
185+
186+
const nbOtherActive = otherCols.length;
187+
const nbThisActive = thisCols.length;
188+
for (let t = 0; t < nbThisActive; t++) {
189+
const i = thisRows[t];
190+
const j = thisCols[t];
191+
for (let o = 0; o < nbOtherActive; o++) {
192+
if (j === otherRows[o]) {
193+
const l = otherCols[o];
194+
result.set(i, l, result.get(i, l) + otherValues[o] * thisValues[t]);
178195
}
179-
});
180-
});
196+
}
197+
}
198+
181199
return result;
182200
}
183201

@@ -195,13 +213,31 @@ export class SparseMatrix {
195213
initialCapacity: this.cardinality * other.cardinality,
196214
});
197215

198-
this.withEachNonZero((i, j, v1) => {
199-
const pi = p * i;
200-
const qj = q * j;
201-
other.withEachNonZero((k, l, v2) => {
202-
result.set(pi + k, qj + l, v1 * v2);
203-
});
204-
});
216+
const {
217+
columns: otherCols,
218+
rows: otherRows,
219+
values: otherValues,
220+
} = other.getNonZeros();
221+
const {
222+
columns: thisCols,
223+
rows: thisRows,
224+
values: thisValues,
225+
} = this.getNonZeros();
226+
227+
const nbOtherActive = otherCols.length;
228+
const nbThisActive = thisCols.length;
229+
for (let t = 0; t < nbThisActive; t++) {
230+
const pi = p * thisRows[t];
231+
const qj = q * thisCols[t];
232+
for (let o = 0; o < nbOtherActive; o++) {
233+
result.set(
234+
pi + otherRows[o],
235+
qj + otherCols[o],
236+
otherValues[o] * thisValues[t],
237+
);
238+
}
239+
}
240+
205241
return result;
206242
}
207243

@@ -258,12 +294,11 @@ export class SparseMatrix {
258294
/** @type {number[]} */
259295
const values = new Array(cardinality);
260296
let idx = 0;
261-
this.forEachNonZero((i, j, value) => {
297+
this.withEachNonZero((i, j, value) => {
262298
rows[idx] = i;
263299
columns[idx] = j;
264300
values[idx] = value;
265301
idx++;
266-
return value;
267302
});
268303
return { rows, columns, values };
269304
}
@@ -287,9 +322,8 @@ export class SparseMatrix {
287322
let trans = new SparseMatrix(this.columns, this.rows, {
288323
initialCapacity: this.cardinality,
289324
});
290-
this.forEachNonZero((i, j, value) => {
325+
this.withEachNonZero((i, j, value) => {
291326
trans.set(j, i, value);
292-
return value;
293327
});
294328
return trans;
295329
}

0 commit comments

Comments
 (0)