Skip to content

Commit e4a7002

Browse files
committed
feat: improve performance of mmul and kronecker product
1 parent 27e7e9b commit e4a7002

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

src/__tests__/test.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ describe('Sparse Matrix', () => {
3636
expect(m3.cardinality).toBe(1);
3737

3838
expect(m3.get(0, 1)).toBe(2);
39+
expect(m3.to2DArray()).toStrictEqual([
40+
[0, 2],
41+
[0, 0],
42+
]);
3943
});
4044

4145
it('kronecker', () => {

src/index.js

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,32 @@ export class SparseMatrix {
158158
const m = this.rows;
159159
const p = other.columns;
160160

161+
const {
162+
columns: otherCols,
163+
rows: otherRows,
164+
values: otherValues,
165+
} = other.getNonZeros();
166+
const {
167+
columns: thisCols,
168+
rows: thisRows,
169+
values: thisValues,
170+
} = this.getNonZeros();
171+
161172
const result = new SparseMatrix(m, p);
162-
this.withEachNonZero((i, j, v1) => {
163-
other.withEachNonZero((k, l, v2) => {
164-
if (j === k) {
165-
result.set(i, l, result.get(i, l) + v1 * v2);
173+
174+
const nbOtherActive = otherCols.length;
175+
const nbThisActive = thisCols.length;
176+
for (let t = 0; t < nbThisActive; t++) {
177+
const i = thisRows[t];
178+
const j = thisCols[t];
179+
for (let o = 0; o < nbOtherActive; o++) {
180+
if (j === otherRows[o]) {
181+
const l = otherCols[o];
182+
result.set(i, l, result.get(i, l) + otherValues[o] * thisValues[t]);
166183
}
167-
});
168-
});
184+
}
185+
}
186+
169187
return result;
170188
}
171189

@@ -179,13 +197,31 @@ export class SparseMatrix {
179197
initialCapacity: this.cardinality * other.cardinality,
180198
});
181199

182-
this.withEachNonZero((i, j, v1) => {
183-
const pi = p * i;
184-
const qj = q * j;
185-
other.withEachNonZero((k, l, v2) => {
186-
result.set(pi + k, qj + l, v1 * v2);
187-
});
188-
});
200+
const {
201+
columns: otherCols,
202+
rows: otherRows,
203+
values: otherValues,
204+
} = other.getNonZeros();
205+
const {
206+
columns: thisCols,
207+
rows: thisRows,
208+
values: thisValues,
209+
} = this.getNonZeros();
210+
211+
const nbOtherActive = otherCols.length;
212+
const nbThisActive = thisCols.length;
213+
for (let t = 0; t < nbThisActive; t++) {
214+
const pi = p * thisRows[t];
215+
const qj = q * thisCols[t];
216+
for (let o = 0; o < nbOtherActive; o++) {
217+
result.set(
218+
pi + otherRows[o],
219+
qj + otherCols[o],
220+
otherValues[o] * thisValues[t],
221+
);
222+
}
223+
}
224+
189225
return result;
190226
}
191227

@@ -229,12 +265,11 @@ export class SparseMatrix {
229265
const columns = new Array(cardinality);
230266
const values = new Array(cardinality);
231267
let idx = 0;
232-
this.forEachNonZero((i, j, value) => {
268+
this.withEachNonZero((i, j, value) => {
233269
rows[idx] = i;
234270
columns[idx] = j;
235271
values[idx] = value;
236272
idx++;
237-
return value;
238273
});
239274
return { rows, columns, values };
240275
}
@@ -254,9 +289,8 @@ export class SparseMatrix {
254289
let trans = new SparseMatrix(this.columns, this.rows, {
255290
initialCapacity: this.cardinality,
256291
});
257-
this.forEachNonZero((i, j, value) => {
292+
this.withEachNonZero((i, j, value) => {
258293
trans.set(j, i, value);
259-
return value;
260294
});
261295
return trans;
262296
}

0 commit comments

Comments
 (0)