Skip to content

Commit e287ca1

Browse files
committed
feat: move internal specialized matrix multiplication as pure functions
1 parent 10aee91 commit e287ca1

File tree

4 files changed

+122
-105
lines changed

4 files changed

+122
-105
lines changed

src/index.js

Lines changed: 7 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import HashTable from 'ml-hash-table';
22

33
import { cooToCsr } from './utils/cooToCsr.js';
4+
import { mmulLowDensity } from './utils/mmulLowDensity.js';
5+
import { mmulMediumDensity } from './utils/mmulMediumDensity.js';
6+
import { mmulSmall } from './utils/mmulSmall.js';
47

58
/** @typedef {(row: number, column: number, value: number) => void} WithEachNonZeroCallback */
69
/** @typedef {(row: number, column: number, value: number) => number | false} ForEachNonZeroCallback */
@@ -170,117 +173,16 @@ export class SparseMatrix {
170173
'Number of columns of left matrix are not equal to number of rows of right matrix.',
171174
);
172175
}
176+
173177
if (this.cardinality < 42 && other.cardinality < 42) {
174-
return this._mmulSmall(other);
178+
return mmulSmall(this, other);
175179
} else if (other.rows > 100 && other.cardinality < 100) {
176-
return this._mmulLowDensity(other);
180+
return mmulLowDensity(this, other);
177181
}
178-
return this._mmulMediumDensity(other);
179-
}
180-
181-
/**
182-
* Matrix multiplication optimized for very small matrices (both cardinalities < 42).
183-
*
184-
* @private
185-
* @param {SparseMatrix} other - The right-hand side matrix to multiply with.
186-
* @returns {SparseMatrix} - The resulting matrix after multiplication.
187-
*/
188-
_mmulSmall(other) {
189-
const m = this.rows;
190-
const p = other.columns;
191-
const {
192-
columns: otherCols,
193-
rows: otherRows,
194-
values: otherValues,
195-
} = other.getNonZeros();
196-
197-
const nbOtherActive = otherCols.length;
198-
const result = new SparseMatrix(m, p);
199-
this.withEachNonZero((i, j, v1) => {
200-
for (let o = 0; o < nbOtherActive; o++) {
201-
if (j === otherRows[o]) {
202-
const l = otherCols[o];
203-
result.set(i, l, result.get(i, l) + otherValues[o] * v1);
204-
}
205-
}
206-
});
207-
return result;
208-
}
209-
210-
/**
211-
* Matrix multiplication optimized for low-density right-hand side matrices (other.rows > 100 and other.cardinality < 100).
212-
*
213-
* @private
214-
* @param {SparseMatrix} other - The right-hand side matrix to multiply with.
215-
* @returns {SparseMatrix} - The resulting matrix after multiplication.
216-
*/
217-
_mmulLowDensity(other) {
218-
const m = this.rows;
219-
const p = other.columns;
220-
const {
221-
columns: otherCols,
222-
rows: otherRows,
223-
values: otherValues,
224-
} = other.getNonZeros();
225-
const {
226-
columns: thisCols,
227-
rows: thisRows,
228-
values: thisValues,
229-
} = this.getNonZeros();
230182

231-
const result = new SparseMatrix(m, p);
232-
const nbOtherActive = otherCols.length;
233-
const nbThisActive = thisCols.length;
234-
for (let t = 0; t < nbThisActive; t++) {
235-
const i = thisRows[t];
236-
const j = thisCols[t];
237-
for (let o = 0; o < nbOtherActive; o++) {
238-
if (j === otherRows[o]) {
239-
const l = otherCols[o];
240-
result.set(i, l, result.get(i, l) + otherValues[o] * thisValues[t]);
241-
}
242-
}
243-
}
244-
// console.log(result.cardinality);
245-
return result;
183+
return mmulMediumDensity(this, other);
246184
}
247185

248-
/**
249-
* Matrix multiplication for medium-density matrices using CSR format for the right-hand side.
250-
*
251-
* @private
252-
* @param {SparseMatrix} other - The right-hand side matrix to multiply with.
253-
* @returns {SparseMatrix} - The resulting matrix after multiplication.
254-
*/
255-
_mmulMediumDensity(other) {
256-
const {
257-
columns: thisCols,
258-
rows: thisRows,
259-
values: thisValues,
260-
} = this.getNonZeros();
261-
const {
262-
columns: otherCols,
263-
rows: otherRows,
264-
values: otherValues,
265-
} = other.getNonZeros({ csr: true });
266-
267-
const m = this.rows;
268-
const p = other.columns;
269-
const result = new SparseMatrix(m, p);
270-
const nbThisActive = thisCols.length;
271-
for (let t = 0; t < nbThisActive; t++) {
272-
const i = thisRows[t];
273-
const j = thisCols[t];
274-
const oStart = otherRows[j];
275-
const oEnd = otherRows[j + 1];
276-
for (let k = oStart; k < oEnd; k++) {
277-
const l = otherCols[k];
278-
result.set(i, l, result.get(i, l) + otherValues[k] * thisValues[t]);
279-
}
280-
}
281-
282-
return result;
283-
}
284186
/**
285187
* @param {SparseMatrix} other
286188
* @returns {SparseMatrix}

src/utils/mmulLowDensity.js

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { SparseMatrix } from '../index.js';
2+
3+
/**
4+
* Multiplies two sparse matrices, optimized for cases where the right-hand side matrix
5+
* has low density (number of rows > 100 and cardinality < 100).
6+
*
7+
* @private
8+
* @param {SparseMatrix} left - The left-hand side matrix.
9+
* @param {SparseMatrix} right - The right-hand side matrix to multiply with.
10+
* @returns {SparseMatrix} The resulting matrix after multiplication.
11+
*/
12+
export function mmulLowDensity(left, right) {
13+
const {
14+
columns: otherCols,
15+
rows: otherRows,
16+
values: otherValues,
17+
} = right.getNonZeros();
18+
const {
19+
columns: thisCols,
20+
rows: thisRows,
21+
values: thisValues,
22+
} = left.getNonZeros();
23+
24+
const m = left.rows;
25+
const p = right.columns;
26+
const output = new SparseMatrix(m, p);
27+
28+
const nbOtherActive = otherCols.length;
29+
const nbThisActive = thisCols.length;
30+
for (let t = 0; t < nbThisActive; t++) {
31+
const i = thisRows[t];
32+
const j = thisCols[t];
33+
for (let o = 0; o < nbOtherActive; o++) {
34+
if (j === otherRows[o]) {
35+
const l = otherCols[o];
36+
output.set(i, l, output.get(i, l) + otherValues[o] * thisValues[t]);
37+
}
38+
}
39+
}
40+
return output;
41+
}

src/utils/mmulMediumDensity.js

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { SparseMatrix } from '../index.js';
2+
3+
/**
4+
* Multiplies two sparse matrices where the right-hand side matrix is of medium density.
5+
* Uses CSR format for the right-hand side for efficient multiplication.
6+
*
7+
* @private
8+
* @param {SparseMatrix} left - The left-hand side matrix.
9+
* @param {SparseMatrix} right - The right-hand side matrix to multiply with (in CSR format).
10+
* @returns {SparseMatrix} The resulting matrix after multiplication.
11+
*/
12+
13+
export function mmulMediumDensity(left, right) {
14+
const m = left.rows;
15+
const p = right.columns;
16+
const result = new SparseMatrix(m, p);
17+
const {
18+
columns: thisCols,
19+
rows: thisRows,
20+
values: thisValues,
21+
} = left.getNonZeros();
22+
const {
23+
columns: otherCols,
24+
rows: otherRows,
25+
values: otherValues,
26+
} = right.getNonZeros({ csr: true });
27+
28+
const nbThisActive = thisCols.length;
29+
for (let t = 0; t < nbThisActive; t++) {
30+
const i = thisRows[t];
31+
const j = thisCols[t];
32+
const oStart = otherRows[j];
33+
const oEnd = otherRows[j + 1];
34+
for (let k = oStart; k < oEnd; k++) {
35+
const l = otherCols[k];
36+
result.set(i, l, result.get(i, l) + otherValues[k] * thisValues[t]);
37+
}
38+
}
39+
40+
return result;
41+
}

src/utils/mmulSmall.js

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import { SparseMatrix } from '../index.js';
2+
3+
/**
4+
* Multiplies two very small sparse matrices (both with cardinalities < 42).
5+
*
6+
* @private
7+
* @param {SparseMatrix} left - The left-hand side matrix.
8+
* @param {SparseMatrix} right - The right-hand side matrix to multiply with.
9+
* @returns {SparseMatrix} The resulting matrix after multiplication.
10+
*/
11+
export function mmulSmall(left, right) {
12+
const {
13+
columns: otherCols,
14+
rows: otherRows,
15+
values: otherValues,
16+
} = right.getNonZeros();
17+
18+
const nbOtherActive = otherCols.length;
19+
20+
const m = left.rows;
21+
const p = right.columns;
22+
const output = new SparseMatrix(m, p);
23+
24+
left.withEachNonZero((i, j, v1) => {
25+
for (let o = 0; o < nbOtherActive; o++) {
26+
if (j === otherRows[o]) {
27+
const l = otherCols[o];
28+
output.set(i, l, output.get(i, l) + otherValues[o] * v1);
29+
}
30+
}
31+
});
32+
return output;
33+
}

0 commit comments

Comments
 (0)