@@ -152,14 +152,29 @@ export class SparseMatrix {
152
152
} else {
153
153
this . elements . set ( row * this . columns + column , value ) ;
154
154
}
155
+
155
156
return this ;
156
157
}
157
158
158
159
/**
159
160
* @param {SparseMatrix } other
160
161
* @returns {SparseMatrix }
161
162
*/
162
- mmul ( other ) {
163
+ mulNew ( other ) {
164
+ if ( typeof other !== 'number' ) {
165
+ throw new RangeError ( 'the argument should be a number' ) ;
166
+ }
167
+
168
+ // if (this.cardinality / this.columns / this.rows > 0.1)
169
+ this . withEachNonZero ( ( i , j , v ) => {
170
+ // return v * other;
171
+ this . set ( i , j , v * other ) ;
172
+ } ) ;
173
+
174
+ return this ;
175
+ }
176
+
177
+ mmulNew ( other ) {
163
178
if ( this . columns !== other . rows ) {
164
179
// eslint-disable-next-line no-console
165
180
console . warn (
@@ -181,7 +196,7 @@ export class SparseMatrix {
181
196
values : thisValues ,
182
197
} = this . getNonZeros ( ) ;
183
198
184
- const result = new SparseMatrix ( m , p ) ;
199
+ const result = new SparseMatrix ( m , p , { initialCapacity : m * p } ) ;
185
200
186
201
const nbOtherActive = otherCols . length ;
187
202
const nbThisActive = thisCols . length ;
@@ -199,6 +214,53 @@ export class SparseMatrix {
199
214
return result ;
200
215
}
201
216
217
+ mmul ( other ) {
218
+ if ( this . columns !== other . rows ) {
219
+ // eslint-disable-next-line no-console
220
+ console . warn (
221
+ 'Number of columns of left matrix are not equal to number of rows of right matrix.' ,
222
+ ) ;
223
+ }
224
+
225
+ const m = this . rows ;
226
+ const p = other . columns ;
227
+
228
+ const {
229
+ columns : otherCols ,
230
+ rows : otherRows ,
231
+ values : otherValues ,
232
+ } = other . getNonZeros ( 'csr' ) ;
233
+ const {
234
+ columns : thisCols ,
235
+ rows : thisRows ,
236
+ values : thisValues ,
237
+ } = this . getNonZeros ( 'csc' ) ;
238
+
239
+ const result = new SparseMatrix ( m , p , { initialCapacity : m * p + 20 } ) ;
240
+
241
+ for ( let t = 0 ; t < thisCols . length - 1 ; t ++ ) {
242
+ const j = t ;
243
+ const tValues = thisValues . subarray ( thisCols [ t ] , thisCols [ t + 1 ] ) ;
244
+ const tRows = thisRows . subarray ( thisCols [ t ] , thisCols [ t + 1 ] ) ;
245
+ let initOther = 0 ;
246
+ for ( let o = initOther ; o < otherRows . length - 1 ; o ++ ) {
247
+ if ( o === j ) {
248
+ initOther ++ ;
249
+ const oValues = otherValues . subarray ( otherRows [ o ] , otherRows [ o + 1 ] ) ;
250
+ const oCols = otherCols . subarray ( otherRows [ o ] , otherRows [ o + 1 ] ) ;
251
+ for ( let f = 0 ; f < tValues . length ; f ++ ) {
252
+ for ( let k = 0 ; k < oValues . length ; k ++ ) {
253
+ const i = tRows [ f ] ;
254
+ const l = oCols [ k ] ;
255
+ result . set ( i , l , result . get ( i , l ) + tValues [ f ] * oValues [ k ] ) ;
256
+ }
257
+ }
258
+ } else if ( j < o ) break ;
259
+ }
260
+ }
261
+ return result ;
262
+ }
263
+
202
264
/**
203
265
* @param {SparseMatrix } other
204
266
* @returns {SparseMatrix }
@@ -210,7 +272,7 @@ export class SparseMatrix {
210
272
const q = other . columns ;
211
273
212
274
const result = new SparseMatrix ( m * p , n * q , {
213
- initialCapacity : this . cardinality * other . cardinality ,
275
+ initialCapacity : this . cardinality * other . cardinality + 20 ,
214
276
} ) ;
215
277
216
278
const {
@@ -285,22 +347,30 @@ export class SparseMatrix {
285
347
return this ;
286
348
}
287
349
288
- getNonZeros ( ) {
350
+ //'csr' | 'csc' | undefined
351
+ getNonZeros ( format ) {
289
352
const cardinality = this . cardinality ;
290
- /** @type {number[] } */
291
- const rows = new Array ( cardinality ) ;
292
- /** @type {number[] } */
293
- const columns = new Array ( cardinality ) ;
294
- /** @type {number[] } */
295
- const values = new Array ( cardinality ) ;
353
+ /** @type {Float64Array } */
354
+ const rows = new Float64Array ( cardinality ) ;
355
+ /** @type {Float64Array } */
356
+ const columns = new Float64Array ( cardinality ) ;
357
+ /** @type {Float64Array } */
358
+ const values = new Float64Array ( cardinality ) ;
359
+
296
360
let idx = 0 ;
297
361
this . withEachNonZero ( ( i , j , value ) => {
298
362
rows [ idx ] = i ;
299
363
columns [ idx ] = j ;
300
364
values [ idx ] = value ;
301
365
idx ++ ;
302
366
} ) ;
303
- return { rows, columns, values } ;
367
+
368
+ const cooMatrix = { rows, columns, values } ;
369
+
370
+ if ( ! format ) return cooMatrix ;
371
+
372
+ const csrMatrix = cooToCsr ( cooMatrix , this . rows ) ;
373
+ return format === 'csc' ? csrToCsc ( csrMatrix , this . columns ) : csrMatrix ;
304
374
}
305
375
306
376
/**
@@ -1433,3 +1503,57 @@ SparseMatrix.prototype.klass = 'Matrix';
1433
1503
1434
1504
SparseMatrix . identity = SparseMatrix . eye ;
1435
1505
SparseMatrix . prototype . tensorProduct = SparseMatrix . prototype . kroneckerProduct ;
1506
+
1507
+ function csrToCsc ( csrMatrix , numCols ) {
1508
+ const {
1509
+ values : csrValues ,
1510
+ columns : csrColIndices ,
1511
+ rows : csrRowPtr ,
1512
+ } = csrMatrix ;
1513
+ // Initialize CSC arrays
1514
+ const cscValues = new Float64Array ( csrValues . length ) ;
1515
+ const cscRowIndices = new Float64Array ( csrValues . length ) ;
1516
+ const cscColPtr = new Float64Array ( numCols + 1 ) ;
1517
+
1518
+ // Count non-zeros per column
1519
+ for ( let i = 0 ; i < csrColIndices . length ; i ++ ) {
1520
+ cscColPtr [ csrColIndices [ i ] + 1 ] ++ ;
1521
+ }
1522
+
1523
+ // Compute column pointers (prefix sum)
1524
+ for ( let i = 1 ; i <= numCols ; i ++ ) {
1525
+ cscColPtr [ i ] += cscColPtr [ i - 1 ] ;
1526
+ }
1527
+
1528
+ // Temporary copy for filling values
1529
+ const next = cscColPtr . slice ( ) ;
1530
+
1531
+ // Fill CSC arrays
1532
+ for ( let row = 0 ; row < csrRowPtr . length - 1 ; row ++ ) {
1533
+ for ( let j = csrRowPtr [ row ] , i = 0 ; j < csrRowPtr [ row + 1 ] ; j ++ , i ++ ) {
1534
+ const col = csrColIndices [ j ] ;
1535
+ const pos = next [ col ] ;
1536
+ cscValues [ pos ] = csrValues [ j ] ;
1537
+ cscRowIndices [ pos ] = row ;
1538
+ next [ col ] ++ ;
1539
+ }
1540
+ // if (row === 1) break;
1541
+ }
1542
+
1543
+ return { rows : cscRowIndices , columns : cscColPtr , values : cscValues } ;
1544
+ }
1545
+
1546
+ function cooToCsr ( cooMatrix , nbRows = 9 ) {
1547
+ const { values, columns, rows } = cooMatrix ;
1548
+ //could not be the same length
1549
+ const csrRowPtr = new Float64Array ( nbRows + 1 ) ;
1550
+ const length = values . length ;
1551
+ let currentRow = rows [ 0 ] ;
1552
+ for ( let index = 0 ; index < length ; ) {
1553
+ const prev = index ;
1554
+ while ( currentRow === rows [ index ] && index < length ) ++ index ;
1555
+ csrRowPtr [ currentRow + 1 ] = index ;
1556
+ currentRow += 1 ;
1557
+ }
1558
+ return { rows : csrRowPtr , columns, values } ;
1559
+ }
0 commit comments