Skip to content

Commit 854355d

Browse files
committed
new as interface
1 parent 6d446ff commit 854355d

File tree

1 file changed

+11
-43
lines changed

1 file changed

+11
-43
lines changed

src/asjulia.jl

Lines changed: 11 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
function asArray(f::Function, A::GBVecOrMat{T}; dropzeros=false, freeunpacked=false) where {T}
1+
function as(f::Function, ::Type{<:Union{Matrix, Vector}}, A::GBVecOrMat{T}; dropzeros=false, freeunpacked=false) where {T}
22
if gbget(A, SPARSITY_STATUS) != GBDENSE
33
X = similar(A)
44
if X isa GBVector
55
X[:] = zero(T)
66
else
77
X[:,:] = zero(T)
88
end
9-
#I don't like this, it defeats the purpose of this method, which is to make no copies.
9+
# I don't like this, it defeats the purpose of this method, which is to make no copies.
1010
# But somehow maintaining the input A in its original form is key to the to_vec implementation
11-
# for ChainRules. Temporarily it's fine, it's no worse than it originally was.
11+
# for ChainRules/FiniteDiff. Temporarily it's fine, it's no worse than it originally was.
1212
# TODO: fix this issue with the ChainRules code.
1313
A = eadd(X, A)
1414
end
@@ -28,39 +28,7 @@ function asArray(f::Function, A::GBVecOrMat{T}; dropzeros=false, freeunpacked=fa
2828
return result
2929
end
3030

31-
function asCSCVectors(f::Function, A::GBMatrix{T}; freeunpacked=false) where {T}
32-
colptr, rowidx, values = _unpackcscmatrix!(A)
33-
result = try
34-
f(colptr, rowidx, values, A)
35-
finally
36-
if freeunpacked
37-
ccall(:jl_free, Cvoid, (Ptr{LibGraphBLAS.GrB_Index},), pointer(colptr))
38-
ccall(:jl_free, Cvoid, (Ptr{LibGraphBLAS.GrB_Index},), pointer(rowidx))
39-
ccall(:jl_free, Cvoid, (Ptr{T},), pointer(values))
40-
else
41-
_packcscmatrix!(A, colptr, rowidx, values)
42-
end
43-
end
44-
return result
45-
end
46-
47-
function asCSRVectors(f::Function, A::GBMatrix{T}; freeunpacked=false) where {T}
48-
rowptr, colidx, values = _unpackcsrmatrix!(A)
49-
result = try
50-
f(rowptr, colidx, values, A)
51-
finally
52-
if freeunpacked
53-
ccall(:jl_free, Cvoid, (Ptr{LibGraphBLAS.GrB_Index},), pointer(rowptr))
54-
ccall(:jl_free, Cvoid, (Ptr{LibGraphBLAS.GrB_Index},), pointer(colidx))
55-
ccall(:jl_free, Cvoid, (Ptr{T},), pointer(values))
56-
else
57-
_packcsrmatrix!(A, rowptr, colidx, values)
58-
end
59-
end
60-
return result
61-
end
62-
63-
function asSparseMatrixCSC(f::Function, A::GBMatrix{T}; freeunpacked=false) where {T}
31+
function as(f::Function, ::SparseMatrixCSC, A::GBMatrix{T}; freeunpacked=false) where {T}
6432
colptr, rowidx, values = _unpackcscmatrix!(A)
6533
array = SparseMatrixCSC{T, LibGraphBLAS.GrB_Index}(size(A, 1), size(A, 2), colptr, rowidx, values)
6634
result = try
@@ -77,7 +45,7 @@ function asSparseMatrixCSC(f::Function, A::GBMatrix{T}; freeunpacked=false) wher
7745
return result
7846
end
7947

80-
function asSparseVector(f::Function, A::GBVector{T}; freeunpacked=false) where {T}
48+
function as(f::Function, ::SparseVector, A::GBVector{T}; freeunpacked=false) where {T}
8149
colptr, rowidx, values = _unpackcscmatrix!(A)
8250
vector = SparseVector{T, LibGraphBLAS.GrB_Index}(size(A, 1), rowidx, values)
8351
result = try
@@ -96,37 +64,37 @@ end
9664

9765

9866
function Base.Matrix(A::GBMatrix)
99-
return asArray(A) do arr, _
67+
return as(Matrix, A) do arr, _
10068
return copy(arr)
10169
end
10270
end
10371

10472
function Matrix!(A::GBMatrix)
105-
return asArray(A; freeunpacked=true) do arr, _
73+
return as(Matrix, A; freeunpacked=true) do arr, _
10674
return copy(arr)
10775
end
10876
end
10977

11078
function Base.Vector(v::GBVector)
111-
return asArray(v) do vec, _
79+
return as(Vector, v) do vec, _
11280
return copy(vec)
11381
end
11482
end
11583

11684
function Vector!(v::GBVector)
117-
return asArray(v; freeunpacked=true) do vec, _
85+
return as(Vector, v; freeunpacked=true) do vec, _
11886
return copy(vec)
11987
end
12088
end
12189

12290
function SparseArrays.SparseMatrixCSC(A::GBMatrix)
123-
return asArray(A) do arr, _
91+
return as(SparseMatrixCSC, A) do arr, _
12492
return copy(arr)
12593
end
12694
end
12795

12896
function SparseArrays.SparseVector(v::GBVector)
129-
return asArray(v) do arr, _
97+
return as(SparseVector, v) do arr, _
13098
return copy(arr)
13199
end
132100
end

0 commit comments

Comments
 (0)