Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 119 additions & 20 deletions src/matrix_coo/matrix_coo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,13 @@ function Base.:+(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
# Mark unique entries (first occurrence of each (row, col) pair)
keep_mask = similar(rowind_sorted, Bool, nnz_concat)
kernel_mark! = kernel_mark_unique_coo!(backend)
kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))
kernel_mark!(
keep_mask,
rowind_sorted,
colind_sorted,
nnz_concat;
ndrange = (nnz_concat,),
)

# Compute write indices using cumsum
write_indices = _cumsum_AK(keep_mask)
Expand Down Expand Up @@ -415,42 +421,43 @@ end

# Addition with transpose/adjoint support
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCOO)
for (wrapb, transb, conjb, unwrapb, whereT2) in
trans_adj_wrappers(:DeviceSparseMatrixCOO)
# Skip the case where both are not transposed (already handled above)
(transa == false && transb == false) && continue

TypeA = wrapa(:(T1))
TypeB = wrapb(:(T2))

@eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)),$(whereT2(:T2))}
size(A) == size(B) || throw(
DimensionMismatch(
"dimensions must match: A has dims $(size(A)), B has dims $(size(B))",
),
)

_A = $(unwrapa(:A))
_B = $(unwrapb(:B))

backend_A = get_backend(_A)
backend_B = get_backend(_B)
backend_A == backend_B ||
throw(ArgumentError("Both matrices must have the same backend"))

m, n = size(A)
Ti = eltype(getrowind(_A))
Tv = promote_type(eltype(nonzeros(_A)), eltype(nonzeros(_B)))

# For transposed COO, swap row and column indices
nnz_A = nnz(_A)
nnz_B = nnz(_B)
nnz_concat = nnz_A + nnz_B

# Allocate concatenated arrays
rowind_concat = similar(getrowind(_A), nnz_concat)
colind_concat = similar(getcolind(_A), nnz_concat)
nzval_concat = similar(nonzeros(_A), Tv, nnz_concat)

# Copy entries from A (potentially swapping row/col for transpose)
if $transa
rowind_concat[1:nnz_A] .= getcolind(_A) # Swap for transpose
Expand All @@ -464,7 +471,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
else
nzval_concat[1:nnz_A] .= nonzeros(_A)
end

# Copy entries from B (potentially swapping row/col for transpose)
if $transb
rowind_concat[(nnz_A+1):end] .= getcolind(_B) # Swap for transpose
Expand All @@ -478,29 +485,41 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
else
nzval_concat[(nnz_A+1):end] .= nonzeros(_B)
end

# Sort and compact (same as before)
backend = backend_A
keys = similar(rowind_concat, Ti, nnz_concat)
kernel_make_keys! = kernel_make_csc_keys!(backend)
kernel_make_keys!(keys, rowind_concat, colind_concat, m; ndrange = (nnz_concat,))

kernel_make_keys!(
keys,
rowind_concat,
colind_concat,
m;
ndrange = (nnz_concat,),
)

perm = _sortperm_AK(keys)
rowind_sorted = rowind_concat[perm]
colind_sorted = colind_concat[perm]
nzval_sorted = nzval_concat[perm]

keep_mask = similar(rowind_sorted, Bool, nnz_concat)
kernel_mark! = kernel_mark_unique_coo!(backend)
kernel_mark!(keep_mask, rowind_sorted, colind_sorted, nnz_concat; ndrange = (nnz_concat,))

kernel_mark!(
keep_mask,
rowind_sorted,
colind_sorted,
nnz_concat;
ndrange = (nnz_concat,),
)

write_indices = _cumsum_AK(keep_mask)
nnz_final = @allowscalar write_indices[nnz_concat]

rowind_C = similar(getrowind(_A), nnz_final)
colind_C = similar(getcolind(_A), nnz_final)
nzval_C = similar(nonzeros(_A), Tv, nnz_final)

kernel_compact! = kernel_compact_coo!(backend)
kernel_compact!(
rowind_C,
Expand All @@ -513,7 +532,7 @@ for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparse
nnz_concat;
ndrange = (nnz_concat,),
)

return DeviceSparseMatrixCOO(m, n, rowind_C, colind_C, nzval_C)
end
end
Expand Down Expand Up @@ -587,3 +606,83 @@ function LinearAlgebra.kron(

return DeviceSparseMatrixCOO(m_C, n_C, rowind_C, colind_C, nzval_C)
end

"""
    *(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)

Return the matrix product `A * B` of two COO sparse matrices as a new
`DeviceSparseMatrixCOO`.

The number of columns of `A` must equal the number of rows of `B`, and both operands must
report the same backend. The product is not computed on the COO data directly: each
operand is converted to `DeviceSparseMatrixCSC`, the CSC `*` method is applied, and the
result is converted back to COO.

# Throws
- `DimensionMismatch`: if `size(A, 2) != size(B, 1)`.
- `ArgumentError`: if `get_backend(A)` and `get_backend(B)` differ.

# Examples
```jldoctest
julia> using DeviceSparseArrays, SparseArrays

julia> A = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [2.0, 3.0], 2, 2));

julia> B = DeviceSparseMatrixCOO(sparse([1, 2], [1, 2], [4.0, 5.0], 2, 2));

julia> C = A * B;

julia> collect(C)
2×2 Matrix{Float64}:
 8.0   0.0
 0.0  15.0
```
"""
function Base.:(*)(A::DeviceSparseMatrixCOO, B::DeviceSparseMatrixCOO)
    # Inner dimensions have to agree before anything is converted.
    if size(A, 2) != size(B, 1)
        throw(
            DimensionMismatch(
                "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
            ),
        )
    end

    # Operands living on different backends cannot be combined.
    if get_backend(A) != get_backend(B)
        throw(ArgumentError("Both matrices must have the same backend"))
    end

    # Round-trip through CSC: COO has no efficient direct multiplication algorithm,
    # while CSC provides the sorted structure needed for SpGEMM.
    lhs = DeviceSparseMatrixCSC(A)
    rhs = DeviceSparseMatrixCSC(B)
    return DeviceSparseMatrixCOO(lhs * rhs)
end

# Generate `*` methods for every operand pairing in which at least one side is a
# transpose/adjoint wrapper of a `DeviceSparseMatrixCOO`.
for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:DeviceSparseMatrixCOO),
    (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DeviceSparseMatrixCOO)

    # The pairing with neither operand transposed is the plain COO * COO method,
    # which is defined separately; do not generate it again here.
    if transa == false && transb == false
        continue
    end

    # Argument type expressions for the left and right operands.
    lhs_type = wrapa(:(T1))
    rhs_type = wrapb(:(T2))

    @eval function Base.:(*)(
        A::$lhs_type,
        B::$rhs_type,
    ) where {$(whereT1(:T1)),$(whereT2(:T2))}
        # `size` is queried on the wrapped operands, so the check is done on the
        # dimensions the caller sees.
        if size(A, 2) != size(B, 1)
            throw(
                DimensionMismatch(
                    "second dimension of A, $(size(A,2)), does not match first dimension of B, $(size(B,1))",
                ),
            )
        end

        # The backend is read from the underlying (unwrapped) matrices.
        if get_backend($(unwrapa(:A))) != get_backend($(unwrapb(:B)))
            throw(ArgumentError("Both matrices must have the same backend"))
        end

        # Convert the (possibly wrapped) operands to CSC, multiply there, and convert
        # the product back to COO. The CSC constructor is relied on to handle the
        # transpose/adjoint wrapper.
        lhs = DeviceSparseMatrixCSC(A)
        rhs = DeviceSparseMatrixCSC(B)
        return DeviceSparseMatrixCOO(lhs * rhs)
    end
end
8 changes: 5 additions & 3 deletions src/matrix_coo/matrix_coo_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,18 @@ end

if i <= nnz_in
out_idx = write_indices[i]

# If this is a new entry (or first of duplicates), write it
if i == 1 || (rowind_in[i] != rowind_in[i-1] || colind_in[i] != colind_in[i-1])
rowind_out[out_idx] = rowind_in[i]
colind_out[out_idx] = colind_in[i]

# Sum all duplicates
val_sum = nzval_in[i]
j = i + 1
while j <= nnz_in && rowind_in[j] == rowind_in[i] && colind_in[j] == colind_in[i]
while j <= nnz_in &&
rowind_in[j] == rowind_in[i] &&
colind_in[j] == colind_in[i]
val_sum += nzval_in[j]
j += 1
end
Expand Down
Loading
Loading