Skip to content

Commit 78a52e6

Browse files
committed
Initial commit
1 parent 87b245f commit 78a52e6

File tree

4 files changed

+110
-2
lines changed

4 files changed

+110
-2
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ uuid = "59d54670-b8ac-4d81-ab7a-bb56233e17ab"
33
authors = ["Stefanos Carlström <[email protected]>"]
44
version = "0.1.0"
55

6+
[deps]
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
9+
610
[compat]
711
julia = "1"
812

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,7 @@
33
[![Build Status](https://travis-ci.com/jagot/ThreadedSparseArrays.jl.svg?branch=master)](https://travis-ci.com/jagot/ThreadedSparseArrays.jl)
44
[![Build Status](https://ci.appveyor.com/api/projects/status/github/jagot/ThreadedSparseArrays.jl?svg=true)](https://ci.appveyor.com/project/jagot/ThreadedSparseArrays-jl)
55
[![Codecov](https://codecov.io/gh/jagot/ThreadedSparseArrays.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/jagot/ThreadedSparseArrays.jl)
6+
7+
Simple package providing a wrapper type enabling threaded sparse
8+
matrix–dense matrix multiplication. Based on [this
9+
PR](https://github.com/JuliaLang/julia/pull/29525).

src/ThreadedSparseArrays.jl

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,92 @@
11
module ThreadedSparseArrays
22

3-
greet() = print("Hello World!")
3+
using LinearAlgebra
4+
import LinearAlgebra: mul!
5+
using SparseArrays
6+
import SparseArrays: AdjOrTransStridedOrTriangularMatrix, getcolptr
7+
8+
struct ThreadedSparseMatrixCSC{Tv,Ti,At} <: AbstractSparseMatrix{Tv,Ti}
9+
A::At
10+
ThreadedSparseMatrixCSC(A::At) where {Tv,Ti,At<:AbstractSparseMatrix{Tv,Ti}} =
11+
new{Tv,Ti,At}(A)
12+
end
13+
14+
Base.size(A::ThreadedSparseMatrixCSC, args...) = size(A.A, args...)
15+
Base.eltype(A::ThreadedSparseMatrixCSC) = eltype(A.A)
16+
Base.getindex(A::ThreadedSparseMatrixCSC, args...) = getindex(A.A, args...)
17+
18+
# Need to override printing
19+
# Need to forward findnz, etc
20+
21+
for f in [:rowvals, :nonzeros, :getcolptr]
22+
@eval SparseArrays.$(f)(A::ThreadedSparseMatrixCSC) = SparseArrays.$(f)(A.A)
23+
end
24+
25+
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
26+
A = adjA.parent
27+
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
28+
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
29+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
30+
colptrA = getcolptr(A)
31+
nzv = nonzeros(A)
32+
rv = rowvals(A)
33+
if β != 1
34+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
35+
end
36+
for k = 1:size(C, 2)
37+
Threads.@threads for col = 1:size(A, 2)
38+
@inbounds begin
39+
tmp = zero(eltype(C))
40+
for j = colptrA[col]:(colptrA[col+1] - 1)
41+
tmp += adjoint(nzv[j])*B[rv[j],k]
42+
end
43+
C[col,k] += α*tmp
44+
end
45+
end
46+
end
47+
C
48+
end
49+
50+
function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:ThreadedSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
51+
A = transA.parent
52+
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
53+
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
54+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
55+
nzv = nonzeros(A)
56+
rv = rowvals(A)
57+
if β != 1
58+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
59+
end
60+
Threads.@threads for k = 1:size(C, 2)
61+
@inbounds for col = 1:size(A, 2)
62+
tmp = zero(eltype(C))
63+
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
64+
tmp += transpose(nzv[j])*B[rv[j],k]
65+
end
66+
C[col,k] += tmp * α
67+
end
68+
end
69+
C
70+
end
71+
72+
function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::ThreadedSparseMatrixCSC, α::Number, β::Number)
73+
mX, nX = size(X)
74+
nX == size(A, 1) || throw(DimensionMismatch())
75+
mX == size(C, 1) || throw(DimensionMismatch())
76+
size(A, 2) == size(C, 2) || throw(DimensionMismatch())
77+
rv = rowvals(A)
78+
nzv = nonzeros(A)
79+
if β != 1
80+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
81+
end
82+
Threads.@threads for col = 1:size(A, 2)
83+
@inbounds for multivec_row=1:mX, k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
84+
C[multivec_row, col] += α * X[multivec_row, rv[k]] * nzv[k] # perhaps suboptimal position of α?
85+
end
86+
end
87+
C
88+
end
89+
90+
export ThreadedSparseMatrixCSC
491

592
end # module

test/runtests.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
using ThreadedSparseArrays
2+
using SparseArrays
3+
using LinearAlgebra
24
using Test
35

46
@testset "ThreadedSparseArrays.jl" begin
5-
# Write your own tests here.
7+
M,N = 5000,4000
8+
n = 200
9+
T = ComplexF64
10+
11+
C = sprand(T, N, n, 0.05)
12+
Ct = ThreadedSparseMatrixCSC(C)
13+
14+
eye = Matrix(one(T)*I, N, N)
15+
out = zeros(T, N, n)
16+
LinearAlgebra.mul!(out, eye, Ct)
17+
ref = eye*C
18+
@test norm(ref-out) == 0
619
end

0 commit comments

Comments
 (0)