Skip to content

Commit b6a393f

Browse files
authored
Improve the dispatch for sparse routines (#410)
1 parent 949a457 commit b6a393f

File tree

6 files changed

+93
-16
lines changed

6 files changed

+93
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ SpecialFunctions = "1.3, 2"
4141
StaticArrays = "1"
4242
julia = "1.8"
4343
oneAPI_Level_Zero_Loader_jll = "1.9"
44-
oneAPI_Support_jll = "~0.3.2"
44+
oneAPI_Support_jll = "~0.3.3"
4545

4646
[extras]
4747
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"

lib/mkl/array.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
export oneSparseMatrixCSR

# Root of the oneAPI sparse-array type hierarchy. Mirrors
# SparseArrays.AbstractSparseArray: Tv is the element type, Ti the
# integer index type, N the dimensionality.
abstract type oneAbstractSparseArray{Tv, Ti, N} <: AbstractSparseArray{Tv, Ti, N} end

# Convenience aliases for the 1-D and 2-D cases.
const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2}
6+
7+
"""
    oneSparseMatrixCSR{Tv, Ti}

Device-resident sparse matrix in compressed sparse row (CSR) format,
paired with a oneMKL sparse matrix handle. Element type is `Tv`,
index type is `Ti`.
"""
mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
    handle::matrix_handle_t   # opaque oneMKL sparse-matrix handle
    rowPtr::oneVector{Ti}     # per-row offsets into colVal/nzVal (one-based)
    colVal::oneVector{Ti}     # column index of each stored entry
    nzVal::oneVector{Tv}      # value of each stored entry
    dims::NTuple{2,Int}       # matrix dimensions (m, n)
    nnz::Ti                   # number of stored entries
end
15+
16+
# Total element count (stored and implicit zeros), i.e. m * n.
function Base.length(A::oneSparseMatrixCSR)
    m, n = A.dims
    return m * n
end

# Dimensions as an (m, n) tuple.
Base.size(A::oneSparseMatrixCSR) = A.dims
18+
19+
"""
    size(A::oneSparseMatrixCSR, d::Integer)

Return the extent of `A` along dimension `d`. Follows the
`AbstractArray` contract: dimensions beyond 2 are singleton (extent 1);
only `d < 1` is an error.
"""
function Base.size(A::oneSparseMatrixCSR, d::Integer)
    if d == 1 || d == 2
        return A.dims[d]
    elseif d > 2
        # AbstractArray semantics: trailing dimensions have extent 1,
        # rather than throwing (the previous behavior).
        return 1
    else
        throw(ArgumentError("dimension must be ≥ 1, got $d"))
    end
end
26+
27+
# Number of stored (structural) entries in A.
SparseArrays.nnz(A::oneSparseMatrixCSR) = A.nnz

# Device-resident vector of the stored values (no copy).
SparseArrays.nonzeros(A::oneSparseMatrixCSR) = A.nzVal
29+
30+
# Display GPU sparse matrices by converting to their CPU counterpart.
for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC]
    @eval begin
        # Plain show: delegate to the host representation.
        Base.show(io::IOContext, x::$gpu) = show(io, $cpu(x))

        # Rich REPL display: summary header, then the sparsity pattern.
        function Base.show(io::IO, mime::MIME"text/plain", S::$gpu)
            nstored = nnz(S)
            m, n = size(S)
            suffix = nstored == 1 ? "entry" : "entries"
            print(io, m, "×", n, " ", typeof(S), " with ", nstored, " stored ", suffix)
            if m != 0 && n != 0
                println(io, ":")
                io = IOContext(io, :typeinfo => eltype(S))
                if ndims(S) == 1
                    show(io, $cpu(S))
                else
                    # Base.print_array renders the nice Braille pattern.
                    Base.print_array(io, $cpu(S))
                end
            end
        end
    end
end

lib/mkl/interfaces.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# interfacing with other packages
2+
3+
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, MulAddMul
4+
5+
# Route dense-vector products with a GPU CSR matrix to the oneMKL
# sparse gemv kernel. Computes C = alpha * op(A) * B + beta * C.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSR{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasFloat
    # Symmetric/Hermitian wrapper tags carry no meaning for the sparse
    # kernel here; treat them as a plain (non-transposed) multiply.
    trans = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    sparse_gemv!(trans, _add.alpha, A, B, _add.beta, C)
end
9+
10+
# Route sparse-times-dense matrix products to the oneMKL sparse gemm
# kernel. Computes C = alpha * op(A) * op(B) + beta * C.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
    # Map symmetric/Hermitian wrapper tags to an ordinary 'N' multiply.
    normalize(t) = t in ('S', 's', 'H', 'h') ? 'N' : t
    sparse_gemm!(normalize(tA), normalize(tB), _add.alpha, A, B, _add.beta, C)
end
15+
16+
# LinearAlgebra.generic_trimatdiv! only exists on Julia ≥ 1.10.
# NOTE: the scraped source read `if VERSION v"1.10-"`, which is a syntax
# error — the comparison operator was lost; restore `≥`.
if VERSION ≥ v"1.10-"
    for SparseMatrixType in (:oneSparseMatrixCSR,)
        @eval begin
            # Triangular solve C = op(A) \ B via the oneMKL sparse trsv
            # kernel. tfun (identity/transpose/adjoint) selects op(A).
            function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::oneVector{T}) where T <: BlasFloat
                sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
            end
        end
    end
end

lib/mkl/oneMKL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}
2222
const onemklComplex = Union{ComplexF32,ComplexF64}
2323
const onemklHalf = Union{Float16,ComplexF16}
2424

25+
include("array.jl")
2526
include("utils.jl")
2627
include("wrappers_blas.jl")
2728
include("wrappers_lapack.jl")
2829
include("wrappers_sparse.jl")
2930
include("linalg.jl")
31+
include("interfaces.jl")
3032

3133
function band(A::StridedArray, kl, ku)
3234
m, n = size(A)

lib/mkl/wrappers_sparse.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
export oneSparseMatrixCSR
2-
3-
mutable struct oneSparseMatrixCSR{T}
4-
handle::matrix_handle_t
5-
type::Type{T}
6-
m::Int
7-
n::Int
8-
end
9-
101
for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32),
112
(:onemklSsparse_set_csr_data_64, :Float32 , :Int64),
123
(:onemklDsparse_set_csr_data , :Float64 , :Int32),
@@ -21,12 +12,20 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
2112
onemklXsparse_init_matrix_handle(handle_ptr)
2213
m, n = size(A)
2314
At = SparseMatrixCSC(A |> transpose)
24-
row_ptr = oneVector{$intty}(At.colptr)
25-
col_ind = oneVector{$intty}(At.rowval)
26-
val = oneVector{$elty}(At.nzval)
27-
queue = global_queue(context(val), device(val))
28-
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', row_ptr, col_ind, val)
29-
return oneSparseMatrixCSR{$elty}(handle_ptr[], $elty, m, n)
15+
rowPtr = oneVector{$intty}(At.colptr)
16+
colVal = oneVector{$intty}(At.rowval)
17+
nzVal = oneVector{$elty}(At.nzval)
18+
nnzA = length(At.nzval)
19+
queue = global_queue(context(nzVal), device(nzVal))
20+
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
21+
return oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA)
22+
end
23+
24+
# Copy a device CSR matrix back to the host as a SparseMatrixCSC.
# (Removed an unused `handle_ptr = Ref{matrix_handle_t}()` — the handle
# is never needed for the host-side conversion.)
function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
    # The CSR buffers of A are exactly the CSC representation of its
    # transpose: build the n×m transpose on the host, then transpose
    # back to recover the m×n CSC matrix.
    At = SparseMatrixCSC(reverse(A.dims)..., Array(A.rowPtr), Array(A.colVal), Array(A.nzVal))
    return SparseMatrixCSC(At |> transpose)
end
3130
end
3231
end

test/onemkl.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,8 @@ end
10831083
A = sprand(T, 20, 10, 0.5)
10841084
A = SparseMatrixCSC{T, S}(A)
10851085
B = oneSparseMatrixCSR(A)
1086+
A2 = SparseMatrixCSC(B)
1087+
@test A == A2
10861088
end
10871089
end
10881090

0 commit comments

Comments
 (0)