|
1 | 1 | import HashTable from 'ml-hash-table';
|
2 | 2 |
|
3 | 3 | import { cooToCsr } from './utils/cooToCsr.js';
|
| 4 | +import { mmulLowDensity } from './utils/mmulLowDensity.js'; |
| 5 | +import { mmulMediumDensity } from './utils/mmulMediumDensity.js'; |
| 6 | +import { mmulSmall } from './utils/mmulSmall.js'; |
4 | 7 |
|
5 | 8 | /** @typedef {(row: number, column: number, value: number) => void} WithEachNonZeroCallback */
|
6 | 9 | /** @typedef {(row: number, column: number, value: number) => number | false} ForEachNonZeroCallback */
|
@@ -170,117 +173,16 @@ export class SparseMatrix {
|
170 | 173 | 'Number of columns of left matrix are not equal to number of rows of right matrix.',
|
171 | 174 | );
|
172 | 175 | }
|
| 176 | + |
173 | 177 | if (this.cardinality < 42 && other.cardinality < 42) {
|
174 |
| - return this._mmulSmall(other); |
| 178 | + return mmulSmall(this, other); |
175 | 179 | } else if (other.rows > 100 && other.cardinality < 100) {
|
176 |
| - return this._mmulLowDensity(other); |
| 180 | + return mmulLowDensity(this, other); |
177 | 181 | }
|
178 |
| - return this._mmulMediumDensity(other); |
179 |
| - } |
180 |
| - |
181 |
| - /** |
182 |
| - * Matrix multiplication optimized for very small matrices (both cardinalities < 42). |
183 |
| - * |
184 |
| - * @private |
185 |
| - * @param {SparseMatrix} other - The right-hand side matrix to multiply with. |
186 |
| - * @returns {SparseMatrix} - The resulting matrix after multiplication. |
187 |
| - */ |
188 |
| - _mmulSmall(other) { |
189 |
| - const m = this.rows; |
190 |
| - const p = other.columns; |
191 |
| - const { |
192 |
| - columns: otherCols, |
193 |
| - rows: otherRows, |
194 |
| - values: otherValues, |
195 |
| - } = other.getNonZeros(); |
196 |
| - |
197 |
| - const nbOtherActive = otherCols.length; |
198 |
| - const result = new SparseMatrix(m, p); |
199 |
| - this.withEachNonZero((i, j, v1) => { |
200 |
| - for (let o = 0; o < nbOtherActive; o++) { |
201 |
| - if (j === otherRows[o]) { |
202 |
| - const l = otherCols[o]; |
203 |
| - result.set(i, l, result.get(i, l) + otherValues[o] * v1); |
204 |
| - } |
205 |
| - } |
206 |
| - }); |
207 |
| - return result; |
208 |
| - } |
209 |
| - |
210 |
| - /** |
211 |
| - * Matrix multiplication optimized for low-density right-hand side matrices (other.rows > 100 and other.cardinality < 100). |
212 |
| - * |
213 |
| - * @private |
214 |
| - * @param {SparseMatrix} other - The right-hand side matrix to multiply with. |
215 |
| - * @returns {SparseMatrix} - The resulting matrix after multiplication. |
216 |
| - */ |
217 |
| - _mmulLowDensity(other) { |
218 |
| - const m = this.rows; |
219 |
| - const p = other.columns; |
220 |
| - const { |
221 |
| - columns: otherCols, |
222 |
| - rows: otherRows, |
223 |
| - values: otherValues, |
224 |
| - } = other.getNonZeros(); |
225 |
| - const { |
226 |
| - columns: thisCols, |
227 |
| - rows: thisRows, |
228 |
| - values: thisValues, |
229 |
| - } = this.getNonZeros(); |
230 | 182 |
|
231 |
| - const result = new SparseMatrix(m, p); |
232 |
| - const nbOtherActive = otherCols.length; |
233 |
| - const nbThisActive = thisCols.length; |
234 |
| - for (let t = 0; t < nbThisActive; t++) { |
235 |
| - const i = thisRows[t]; |
236 |
| - const j = thisCols[t]; |
237 |
| - for (let o = 0; o < nbOtherActive; o++) { |
238 |
| - if (j === otherRows[o]) { |
239 |
| - const l = otherCols[o]; |
240 |
| - result.set(i, l, result.get(i, l) + otherValues[o] * thisValues[t]); |
241 |
| - } |
242 |
| - } |
243 |
| - } |
244 |
| - // console.log(result.cardinality); |
245 |
| - return result; |
| 183 | + return mmulMediumDensity(this, other); |
246 | 184 | }
|
247 | 185 |
|
248 |
| - /** |
249 |
| - * Matrix multiplication for medium-density matrices using CSR format for the right-hand side. |
250 |
| - * |
251 |
| - * @private |
252 |
| - * @param {SparseMatrix} other - The right-hand side matrix to multiply with. |
253 |
| - * @returns {SparseMatrix} - The resulting matrix after multiplication. |
254 |
| - */ |
255 |
| - _mmulMediumDensity(other) { |
256 |
| - const { |
257 |
| - columns: thisCols, |
258 |
| - rows: thisRows, |
259 |
| - values: thisValues, |
260 |
| - } = this.getNonZeros(); |
261 |
| - const { |
262 |
| - columns: otherCols, |
263 |
| - rows: otherRows, |
264 |
| - values: otherValues, |
265 |
| - } = other.getNonZeros({ csr: true }); |
266 |
| - |
267 |
| - const m = this.rows; |
268 |
| - const p = other.columns; |
269 |
| - const result = new SparseMatrix(m, p); |
270 |
| - const nbThisActive = thisCols.length; |
271 |
| - for (let t = 0; t < nbThisActive; t++) { |
272 |
| - const i = thisRows[t]; |
273 |
| - const j = thisCols[t]; |
274 |
| - const oStart = otherRows[j]; |
275 |
| - const oEnd = otherRows[j + 1]; |
276 |
| - for (let k = oStart; k < oEnd; k++) { |
277 |
| - const l = otherCols[k]; |
278 |
| - result.set(i, l, result.get(i, l) + otherValues[k] * thisValues[t]); |
279 |
| - } |
280 |
| - } |
281 |
| - |
282 |
| - return result; |
283 |
| - } |
284 | 186 | /**
|
285 | 187 | * @param {SparseMatrix} other
|
286 | 188 | * @returns {SparseMatrix}
|
|
0 commit comments