@@ -131,6 +131,7 @@ export class SparseMatrix {
131
131
} else {
132
132
this . elements . set ( row * this . columns + column , value ) ;
133
133
}
134
+
134
135
return this ;
135
136
}
136
137
@@ -147,7 +148,21 @@ export class SparseMatrix {
147
148
return this ;
148
149
}
149
150
150
- mmul ( other ) {
151
+ mulNew ( other ) {
152
+ if ( typeof other !== 'number' ) {
153
+ throw new RangeError ( 'the argument should be a number' ) ;
154
+ }
155
+
156
+ // if (this.cardinality / this.columns / this.rows > 0.1)
157
+ this . withEachNonZero ( ( i , j , v ) => {
158
+ // return v * other;
159
+ this . set ( i , j , v * other ) ;
160
+ } ) ;
161
+
162
+ return this ;
163
+ }
164
+
165
+ mmulNew ( other ) {
151
166
if ( this . columns !== other . rows ) {
152
167
// eslint-disable-next-line no-console
153
168
console . warn (
@@ -169,7 +184,7 @@ export class SparseMatrix {
169
184
values : thisValues ,
170
185
} = this . getNonZeros ( ) ;
171
186
172
- const result = new SparseMatrix ( m , p ) ;
187
+ const result = new SparseMatrix ( m , p , { initialCapacity : m * p } ) ;
173
188
174
189
const nbOtherActive = otherCols . length ;
175
190
const nbThisActive = thisCols . length ;
@@ -187,14 +202,61 @@ export class SparseMatrix {
187
202
return result ;
188
203
}
189
204
205
+ mmul ( other ) {
206
+ if ( this . columns !== other . rows ) {
207
+ // eslint-disable-next-line no-console
208
+ console . warn (
209
+ 'Number of columns of left matrix are not equal to number of rows of right matrix.' ,
210
+ ) ;
211
+ }
212
+
213
+ const m = this . rows ;
214
+ const p = other . columns ;
215
+
216
+ const {
217
+ columns : otherCols ,
218
+ rows : otherRows ,
219
+ values : otherValues ,
220
+ } = other . getNonZeros ( 'csr' ) ;
221
+ const {
222
+ columns : thisCols ,
223
+ rows : thisRows ,
224
+ values : thisValues ,
225
+ } = this . getNonZeros ( 'csc' ) ;
226
+
227
+ const result = new SparseMatrix ( m , p , { initialCapacity : m * p + 20 } ) ;
228
+
229
+ for ( let t = 0 ; t < thisCols . length - 1 ; t ++ ) {
230
+ const j = t ;
231
+ const tValues = thisValues . subarray ( thisCols [ t ] , thisCols [ t + 1 ] ) ;
232
+ const tRows = thisRows . subarray ( thisCols [ t ] , thisCols [ t + 1 ] ) ;
233
+ let initOther = 0 ;
234
+ for ( let o = initOther ; o < otherRows . length - 1 ; o ++ ) {
235
+ if ( o === j ) {
236
+ initOther ++ ;
237
+ const oValues = otherValues . subarray ( otherRows [ o ] , otherRows [ o + 1 ] ) ;
238
+ const oCols = otherCols . subarray ( otherRows [ o ] , otherRows [ o + 1 ] ) ;
239
+ for ( let f = 0 ; f < tValues . length ; f ++ ) {
240
+ for ( let k = 0 ; k < oValues . length ; k ++ ) {
241
+ const i = tRows [ f ] ;
242
+ const l = oCols [ k ] ;
243
+ result . set ( i , l , result . get ( i , l ) + tValues [ f ] * oValues [ k ] ) ;
244
+ }
245
+ }
246
+ } else if ( j < o ) break ;
247
+ }
248
+ }
249
+ return result ;
250
+ }
251
+
190
252
kroneckerProduct ( other ) {
191
253
const m = this . rows ;
192
254
const n = this . columns ;
193
255
const p = other . rows ;
194
256
const q = other . columns ;
195
257
196
258
const result = new SparseMatrix ( m * p , n * q , {
197
- initialCapacity : this . cardinality * other . cardinality ,
259
+ initialCapacity : this . cardinality * other . cardinality + 20 ,
198
260
} ) ;
199
261
200
262
const {
@@ -259,19 +321,26 @@ export class SparseMatrix {
259
321
return this ;
260
322
}
261
323
262
- getNonZeros ( ) {
324
+ //'csr' | 'csc' | undefined
325
+ getNonZeros ( format ) {
263
326
const cardinality = this . cardinality ;
264
- const rows = new Array ( cardinality ) ;
265
- const columns = new Array ( cardinality ) ;
266
- const values = new Array ( cardinality ) ;
327
+ const rows = new Float64Array ( cardinality ) ;
328
+ const columns = new Float64Array ( cardinality ) ;
329
+ const values = new Float64Array ( cardinality ) ;
267
330
let idx = 0 ;
268
331
this . withEachNonZero ( ( i , j , value ) => {
269
332
rows [ idx ] = i ;
270
333
columns [ idx ] = j ;
271
334
values [ idx ] = value ;
272
335
idx ++ ;
273
336
} ) ;
274
- return { rows, columns, values } ;
337
+
338
+ const cooMatrix = { rows, columns, values } ;
339
+
340
+ if ( ! format ) return cooMatrix ;
341
+
342
+ const csrMatrix = cooToCsr ( cooMatrix , this . rows ) ;
343
+ return format === 'csc' ? csrToCsc ( csrMatrix , this . columns ) : csrMatrix ;
275
344
}
276
345
277
346
setThreshold ( newThreshold ) {
@@ -452,3 +521,57 @@ function fillTemplateFunction(template, values) {
452
521
}
453
522
return template ;
454
523
}
524
+
525
+ function csrToCsc ( csrMatrix , numCols ) {
526
+ const {
527
+ values : csrValues ,
528
+ columns : csrColIndices ,
529
+ rows : csrRowPtr ,
530
+ } = csrMatrix ;
531
+ // Initialize CSC arrays
532
+ const cscValues = new Float64Array ( csrValues . length ) ;
533
+ const cscRowIndices = new Float64Array ( csrValues . length ) ;
534
+ const cscColPtr = new Float64Array ( numCols + 1 ) ;
535
+
536
+ // Count non-zeros per column
537
+ for ( let i = 0 ; i < csrColIndices . length ; i ++ ) {
538
+ cscColPtr [ csrColIndices [ i ] + 1 ] ++ ;
539
+ }
540
+
541
+ // Compute column pointers (prefix sum)
542
+ for ( let i = 1 ; i <= numCols ; i ++ ) {
543
+ cscColPtr [ i ] += cscColPtr [ i - 1 ] ;
544
+ }
545
+
546
+ // Temporary copy for filling values
547
+ const next = cscColPtr . slice ( ) ;
548
+
549
+ // Fill CSC arrays
550
+ for ( let row = 0 ; row < csrRowPtr . length - 1 ; row ++ ) {
551
+ for ( let j = csrRowPtr [ row ] , i = 0 ; j < csrRowPtr [ row + 1 ] ; j ++ , i ++ ) {
552
+ const col = csrColIndices [ j ] ;
553
+ const pos = next [ col ] ;
554
+ cscValues [ pos ] = csrValues [ j ] ;
555
+ cscRowIndices [ pos ] = row ;
556
+ next [ col ] ++ ;
557
+ }
558
+ // if (row === 1) break;
559
+ }
560
+
561
+ return { rows : cscRowIndices , columns : cscColPtr , values : cscValues } ;
562
+ }
563
+
564
+ function cooToCsr ( cooMatrix , nbRows = 9 ) {
565
+ const { values, columns, rows } = cooMatrix ;
566
+ //could not be the same length
567
+ const csrRowPtr = new Float64Array ( nbRows + 1 ) ;
568
+ const length = values . length ;
569
+ let currentRow = rows [ 0 ] ;
570
+ for ( let index = 0 ; index < length ; ) {
571
+ const prev = index ;
572
+ while ( currentRow === rows [ index ] && index < length ) ++ index ;
573
+ csrRowPtr [ currentRow + 1 ] = index ;
574
+ currentRow += 1 ;
575
+ }
576
+ return { rows : csrRowPtr , columns, values } ;
577
+ }
0 commit comments