Skip to content

Commit 4ebee38

Browse files
Add comprehensive tests for sparse-sparse multiplication
Co-authored-by: albertomercurio <[email protected]>
1 parent 489f945 commit 4ebee38

File tree

7 files changed

+243
-73
lines changed

7 files changed

+243
-73
lines changed

src/matrix_coo/matrix_coo.jl

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,13 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
385385
# Mark unique entries (first occurrence of each (row, col) pair)
386386
keep_mask = similar(rowind_sorted, Bool, nnz_concat)
387387
kernel_mark! = kernel_mark_unique_coo!(backend)
388-
kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
388+
kernel_mark!(
389+
keep_mask,
390+
rowind_sorted,
391+
colind_sorted,
392+
nnz_concat;
393+
ndrange = (nnz_concat,),
394+
)
389395

390396
# Compute write indices using cumsum
391397
write_indices = _cumsum_AK(keep_mask)
@@ -415,42 +421,43 @@ end
415421

416422
# Addition with transpose/adjoint support
417423
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
418-
for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
424+
for (wrapb, transb, conjb, unwrapb, whereT2) in
425+
trans_adj_wrappers(:DeviceSparseMatrixCOO)
419426
# Skip the case where both are not transposed (already handled above)
420427
(transa == false && transb == false) && continue
421-
428+
422429
TypeA = wrapa(:(T1))
423430
TypeB = wrapb(:(T2))
424-
431+
425432
@eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
426433
size(A) == size(B) || throw(
427434
DimensionMismatch(
428435
"dimensions must match: A has dims $(size(A)), B has dims $(size(B))",
429436
),
430437
)
431-
438+
432439
_A = $(unwrapa(:A))
433440
_B = $(unwrapb(:B))
434-
441+
435442
backend_A = get_backend(_A)
436443
backend_B = get_backend(_B)
437444
backend_A == backend_B ||
438445
throw(ArgumentError("Both matrices must have the same backend"))
439-
446+
440447
m, n = size(A)
441448
Ti = eltype(getrowind(_A))
442449
Tv = promote_type(eltype(nonzeros(_A)), eltype(nonzeros(_B)))
443-
450+
444451
# For transposed COO, swap row and column indices
445452
nnz_A = nnz(_A)
446453
nnz_B = nnz(_B)
447454
nnz_concat = nnz_A + nnz_B
448-
455+
449456
# Allocate concatenated arrays
450457
rowind_concat = similar(getrowind(_A), nnz_concat)
451458
colind_concat = similar(getcolind(_A), nnz_concat)
452459
nzval_concat = similar(nonzeros(_A), Tv, nnz_concat)
453-
460+
454461
# Copy entries from A (potentially swapping row/col for transpose)
455462
if $transa
456463
rowind_concat[1:nnz_A] .= getcolind(_A) # Swap for transpose
@@ -464,7 +471,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
464471
else
465472
nzval_concat[1:nnz_A] .= nonzeros(_A)
466473
end
467-
474+
468475
# Copy entries from B (potentially swapping row/col for transpose)
469476
if $transb
470477
rowind_concat[(nnz_A+1):end] .= getcolind(_B) # Swap for transpose
@@ -478,29 +485,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
478485
else
479486
nzval_concat[(nnz_A+1):end] .= nonzeros(_B)
480487
end
481-
488+
482489
# Sort and compact (same as before)
483490
backend = backend_A
484491
keys = similar(rowind_concat, Ti, nnz_concat)
485492
kernel_make_keys! = kernel_make_csc_keys!(backend)
486-
kernel_make_keys!(keys, rowind_concat, colind_concat, m; ndrange = (nnz_concat,))
487-
493+
kernel_make_keys!(
494+
keys,
495+
rowind_concat,
496+
colind_concat,
497+
m;
498+
ndrange = (nnz_concat,),
499+
)
500+
488501
perm = _sortperm_AK(keys)
489502
rowind_sorted = rowind_concat[perm]
490503
colind_sorted = colind_concat[perm]
491504
nzval_sorted = nzval_concat[perm]
492-
505+
493506
keep_mask = similar(rowind_sorted, Bool, nnz_concat)
494507
kernel_mark! = kernel_mark_unique_coo!(backend)
495-
kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
496-
508+
kernel_mark!(
509+
keep_mask,
510+
rowind_sorted,
511+
colind_sorted,
512+
nnz_concat;
513+
ndrange = (nnz_concat,),
514+
)
515+
497516
write_indices = _cumsum_AK(keep_mask)
498517
nnz_final = @allowscalar write_indices[nnz_concat]
499-
518+
500519
rowind_C = similar(getrowind(_A), nnz_final)
501520
colind_C = similar(getcolind(_A), nnz_final)
502521
nzval_C = similar(nonzeros(_A), Tv, nnz_final)
503-
522+
504523
kernel_compact! = kernel_compact_coo!(backend)
505524
kernel_compact!(
506525
rowind_C,
@@ -513,7 +532,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
513532
nnz_concat;
514533
ndrange = (nnz_concat,),
515534
)
516-
535+
517536
return DeviceSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C)
518537
end
519538
end
@@ -630,38 +649,42 @@ function Base.:(*)(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
630649
B_sparse = SparseMatrixCSC(B)
631650
C_sparse = A_sparse * B_sparse
632651
C = DeviceSparseMatrixCOO(C_sparse)
633-
652+
634653
# Adapt to the same backend as A and B
635654
return Adapt.adapt(backend_A, C)
636655
end
637656

638657
# Multiplication with transpose/adjoint support
639658
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
640-
for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
659+
for (wrapb, transb, conjb, unwrapb, whereT2) in
660+
trans_adj_wrappers(:DeviceSparseMatrixCOO)
641661
# Skip the case where both are not transposed (already handled above)
642662
(transa == false && transb == false) && continue
643-
663+
644664
TypeA = wrapa(:(T1))
645665
TypeB = wrapb(:(T2))
646-
647-
@eval function Base.:(*)(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
666+
667+
@eval function Base.:(*)(
668+
A::$TypeA,
669+
B::$TypeB,
670+
) where {$(whereT1(:T1)),$(whereT2(:T2))}
648671
size(A, 2) == size(B, 1) || throw(
649672
DimensionMismatch(
650673
"second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
651674
),
652675
)
653-
676+
654677
backend_A = get_backend($(unwrapa(:A)))
655678
backend_B = get_backend($(unwrapb(:B)))
656679
backend_A == backend_B ||
657680
throw(ArgumentError("Both matrices must have the same backend"))
658-
681+
659682
# Convert to SparseMatrixCSC (handles transpose/adjoint), multiply, convert back
660683
A_sparse = SparseMatrixCSC(A)
661684
B_sparse = SparseMatrixCSC(B)
662685
C_sparse = A_sparse * B_sparse
663686
C = DeviceSparseMatrixCOO(C_sparse)
664-
687+
665688
# Adapt to the same backend as A and B
666689
return Adapt.adapt(backend_A, C)
667690
end

src/matrix_coo/matrix_coo_kernels.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,16 +216,18 @@ end
216216

217217
if i <= nnz_in
218218
out_idx = write_indices[i]
219-
219+
220220
# If this is a new entry (or first of duplicates), write it
221221
if i == 1 || (rowind_in[i] != rowind_in[i-1] || colind_in[i] != colind_in[i-1])
222222
rowind_out[out_idx] = rowind_in[i]
223223
colind_out[out_idx] = colind_in[i]
224-
224+
225225
# Sum all duplicates
226226
val_sum = nzval_in[i]
227227
j = i + 1
228-
while j <= nnz_in && rowind_in[j] == rowind_in[i] && colind_in[j] == colind_in[i]
228+
while j <= nnz_in &&
229+
rowind_in[j] == rowind_in[i] &&
230+
colind_in[j] == colind_in[i]
229231
val_sum += nzval_in[j]
230232
j += 1
231233
end

src/matrix_csc/matrix_csc.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ function Base.:+(A::DeviceSparseMatrixCSC, B::DeviceSparseMatrixCSC)
365365
colptr_C[1:1] .= one(Ti)
366366

367367
# Allocate result arrays
368-
nnz_total = @allowscalar colptr_C[n+1] - one(Ti)
368+
nnz_total = @allowscalar colptr_C[n+1] - one(Ti)
369369
rowval_C = similar(getrowval(A), nnz_total)
370370
nzval_C = similar(nonzeros(A), Tv, nnz_total)
371371

@@ -391,27 +391,28 @@ end
391391

392392
# Addition with transpose/adjoint support
393393
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCSC)
394-
for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCSC)
394+
for (wrapb, transb, conjb, unwrapb, whereT2) in
395+
trans_adj_wrappers(:DeviceSparseMatrixCSC)
395396
# Skip the case where both are not transposed (already handled above)
396397
(transa == false && transb == false) && continue
397-
398+
398399
TypeA = wrapa(:(T1))
399400
TypeB = wrapb(:(T2))
400-
401+
401402
@eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
402403
size(A) == size(B) || throw(
403404
DimensionMismatch(
404405
"dimensions must match: A has dims $(size(A)), B has dims $(size(B))",
405406
),
406407
)
407-
408+
408409
# Convert both to CSR (transpose/adjoint of CSC has CSR structure)
409410
# and use existing CSR + CSR addition. The conversion methods
410411
# already handle transpose/adjoint correctly.
411412
A_csr = DeviceSparseMatrixCSR(A)
412413
B_csr = DeviceSparseMatrixCSR(B)
413414
result_csr = A_csr + B_csr
414-
415+
415416
# Convert back to CSC
416417
return DeviceSparseMatrixCSC(result_csr)
417418
end
@@ -493,38 +494,42 @@ function Base.:(*)(A::DeviceSparseMatrixCSC, B::DeviceSparseMatrixCSC)
493494
B_sparse = SparseMatrixCSC(B)
494495
C_sparse = A_sparse * B_sparse
495496
C = DeviceSparseMatrixCSC(C_sparse)
496-
497+
497498
# Adapt to the same backend as A and B
498499
return Adapt.adapt(backend_A, C)
499500
end
500501

501502
# Multiplication with transpose/adjoint support
502503
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCSC)
503-
for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCSC)
504+
for (wrapb, transb, conjb, unwrapb, whereT2) in
505+
trans_adj_wrappers(:DeviceSparseMatrixCSC)
504506
# Skip the case where both are not transposed (already handled above)
505507
(transa == false && transb == false) && continue
506-
508+
507509
TypeA = wrapa(:(T1))
508510
TypeB = wrapb(:(T2))
509-
510-
@eval function Base.:(*)(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
511+
512+
@eval function Base.:(*)(
513+
A::$TypeA,
514+
B::$TypeB,
515+
) where {$(whereT1(:T1)),$(whereT2(:T2))}
511516
size(A, 2) == size(B, 1) || throw(
512517
DimensionMismatch(
513518
"second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
514519
),
515520
)
516-
521+
517522
backend_A = get_backend($(unwrapa(:A)))
518523
backend_B = get_backend($(unwrapb(:B)))
519524
backend_A == backend_B ||
520525
throw(ArgumentError("Both matrices must have the same backend"))
521-
526+
522527
# Convert to SparseMatrixCSC (handles transpose/adjoint), multiply, convert back
523528
A_sparse = SparseMatrixCSC(A)
524529
B_sparse = SparseMatrixCSC(B)
525530
C_sparse = A_sparse * B_sparse
526531
C = DeviceSparseMatrixCSC(C_sparse)
527-
532+
528533
# Adapt to the same backend as A and B
529534
return Adapt.adapt(backend_A, C)
530535
end

src/matrix_csr/matrix_csr.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -389,27 +389,28 @@ end
389389

390390
# Addition with transpose/adjoint support
391391
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCSR)
392-
for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCSR)
392+
for (wrapb, transb, conjb, unwrapb, whereT2) in
393+
trans_adj_wrappers(:DeviceSparseMatrixCSR)
393394
# Skip the case where both are not transposed (already handled above)
394395
(transa == false && transb == false) && continue
395-
396+
396397
TypeA = wrapa(:(T1))
397398
TypeB = wrapb(:(T2))
398-
399+
399400
@eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
400401
size(A) == size(B) || throw(
401402
DimensionMismatch(
402403
"dimensions must match: A has dims $(size(A)), B has dims $(size(B))",
403404
),
404405
)
405-
406+
406407
# Convert both to CSC (transpose/adjoint of CSR has CSC structure)
407408
# and use existing CSC + CSC addition. The conversion methods
408409
# already handle transpose/adjoint correctly.
409410
A_csc = DeviceSparseMatrixCSC(A)
410411
B_csc = DeviceSparseMatrixCSC(B)
411412
result_csc = A_csc + B_csc
412-
413+
413414
# Convert back to CSR
414415
return DeviceSparseMatrixCSR(result_csc)
415416
end
@@ -499,38 +500,42 @@ function Base.:(*)(A::DeviceSparseMatrixCSR, B::DeviceSparseMatrixCSR)
499500
B_sparse = SparseMatrixCSC(B)
500501
C_sparse = A_sparse * B_sparse
501502
C = DeviceSparseMatrixCSR(C_sparse)
502-
503+
503504
# Adapt to the same backend as A and B
504505
return Adapt.adapt(backend_A, C)
505506
end
506507

507508
# Multiplication with transpose/adjoint support
508509
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCSR)
509-
for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCSR)
510+
for (wrapb, transb, conjb, unwrapb, whereT2) in
511+
trans_adj_wrappers(:DeviceSparseMatrixCSR)
510512
# Skip the case where both are not transposed (already handled above)
511513
(transa == false && transb == false) && continue
512-
514+
513515
TypeA = wrapa(:(T1))
514516
TypeB = wrapb(:(T2))
515-
516-
@eval function Base.:(*)(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
517+
518+
@eval function Base.:(*)(
519+
A::$TypeA,
520+
B::$TypeB,
521+
) where {$(whereT1(:T1)),$(whereT2(:T2))}
517522
size(A, 2) == size(B, 1) || throw(
518523
DimensionMismatch(
519524
"second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
520525
),
521526
)
522-
527+
523528
backend_A = get_backend($(unwrapa(:A)))
524529
backend_B = get_backend($(unwrapb(:B)))
525530
backend_A == backend_B ||
526531
throw(ArgumentError("Both matrices must have the same backend"))
527-
532+
528533
# Convert to SparseMatrixCSC (handles transpose/adjoint), multiply, convert back
529534
A_sparse = SparseMatrixCSC(A)
530535
B_sparse = SparseMatrixCSC(B)
531536
C_sparse = A_sparse * B_sparse
532537
C = DeviceSparseMatrixCSR(C_sparse)
533-
538+
534539
# Adapt to the same backend as A and B
535540
return Adapt.adapt(backend_A, C)
536541
end

0 commit comments

Comments
 (0)