Skip to content

Commit 099f24f

Browse files
committed
feat: add new matrix addition method and improve performance of mmul and kroneckerProduct method
1 parent 9361aa3 commit 099f24f

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

src/__tests__/index.test.js

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,24 @@ import { describe, expect, it } from 'vitest';
33
import { SparseMatrix } from '../index.js';
44

55
describe('Sparse Matrix', () => {
6+
it('add', () => {
7+
let m1 = new SparseMatrix([
8+
[2, 0, 1],
9+
[0, 0, 3],
10+
[2, 0, 1],
11+
]);
12+
let m2 = new SparseMatrix([
13+
[0, 1, 5],
14+
[2, 0, 0],
15+
[-2, 0, -1],
16+
]);
17+
let m3 = m1.add(m2).to2DArray();
18+
expect(m3).toStrictEqual([
19+
[2, 1, 6],
20+
[2, 0, 3],
21+
[0, 0, 0],
22+
]);
23+
});
624
it('mmul', () => {
725
let m1 = new SparseMatrix([
826
[2, 0, 1],

src/index.js

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ export class SparseMatrix {
1818

1919
if (Array.isArray(rows)) {
2020
const matrix = rows;
21-
rows = matrix.length;
22-
options = columns || {};
21+
const nbRows = matrix.length;
22+
const nbColmuns = matrix[0].length;
23+
options = columns || { initialCapacity: nbRows * nbColmuns };
2324
columns = matrix[0].length;
24-
this._init(rows, columns, new HashTable(options), options.threshold);
25-
for (let i = 0; i < rows; i++) {
25+
this._init(nbRows, columns, new HashTable(options), options.threshold);
26+
for (let i = 0; i < nbRows; i++) {
2627
for (let j = 0; j < columns; j++) {
2728
let value = matrix[i][j];
2829
if (this.threshold && Math.abs(value) < this.threshold) value = 0;
@@ -31,6 +32,7 @@ export class SparseMatrix {
3132
}
3233
}
3334
}
35+
this.elements.maybeShrinkCapacity();
3436
} else {
3537
this._init(rows, columns, new HashTable(options), options.threshold);
3638
}
@@ -169,14 +171,12 @@ export class SparseMatrix {
169171
const p = other.columns;
170172

171173
const result = new SparseMatrix(m, p);
172-
this.forEachNonZero((i, j, v1) => {
173-
other.forEachNonZero((k, l, v2) => {
174+
this.withEachNonZero((i, j, v1) => {
175+
other.withEachNonZero((k, l, v2) => {
174176
if (j === k) {
175177
result.set(i, l, result.get(i, l) + v1 * v2);
176178
}
177-
return v2;
178179
});
179-
return v1;
180180
});
181181
return result;
182182
}
@@ -194,16 +194,31 @@ export class SparseMatrix {
194194
const result = new SparseMatrix(m * p, n * q, {
195195
initialCapacity: this.cardinality * other.cardinality,
196196
});
197-
this.forEachNonZero((i, j, v1) => {
198-
other.forEachNonZero((k, l, v2) => {
199-
result.set(p * i + k, q * j + l, v1 * v2);
200-
return v2;
197+
198+
this.withEachNonZero((i, j, v1) => {
199+
const pi = p * i;
200+
const qj = q * j;
201+
other.withEachNonZero((k, l, v2) => {
202+
result.set(pi + k, qj + l, v1 * v2);
201203
});
202-
return v1;
203204
});
204205
return result;
205206
}
206207

208+
withEachNonZero(callback) {
209+
const { state, table, values } = this.elements;
210+
const nbStates = state.length;
211+
const activeIndex = [];
212+
for (let i = 0; i < nbStates; i++) {
213+
if (state[i] === 1) activeIndex.push(i);
214+
}
215+
const columns = this.columns;
216+
for (const i of activeIndex) {
217+
const key = table[i];
218+
callback((key / columns) | 0, key % columns, values[i]);
219+
}
220+
}
221+
207222
/**
208223
* Calls `callback` for each value in the matrix that is not zero.
209224
* The callback can return:

0 commit comments

Comments
 (0)