Skip to content

Commit e171b35

Browse files
committed
fix failed conversion from GBMatrix{Int} to GBMatrix{Float32} #97
1 parent d565180 commit e171b35

File tree

5 files changed

+39
-22
lines changed

5 files changed

+39
-22
lines changed

src/abstractgbarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Base.eltype(::AbstractGBArray{T, F}) where {T, F} = Union{T, F}
6767
Base.eltype(::Type{<:AbstractGBArray{T, F}}) where{T, F} = Union{T, F}
6868

6969
storedeltype(x) = eltype(x)
70+
storedeltype(::Type{<:AbstractGBArray{T}}) where T = T
7071
storedeltype(::AbstractGBArray{T}) where T = T
7172

7273
Base.unsafe_convert(::Type{LibGraphBLAS.GrB_Matrix}, A::AbstractGBArray) = A.p[]

src/convert.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,37 @@
11
# First we'll just support element type conversions.
22
# This is crucial since we can't pass DataTypes to UDF handlers.
33

4-
# pass through for most cases
5-
conform(M::AbstractGBArray) = M
6-
7-
function Base.convert(::Type{M}, A::N; fill::F = getfill(A)) where {F, M<:AbstractGBArray, N<:AbstractGBArray}
8-
!(F <: Union{Nothing, Missing}) && (fill = convert(eltype(M), fill))
9-
isabstracttype(M) && throw(ArgumentError("$M is an abstract type, which cannot be constructed."))
4+
function applyjl!(F, C::AbstractGBArray, A::AbstractGBArray)
5+
isabstracttype(F) && throw(ArgumentError("$M is an abstract type, which cannot be constructed."))
106
x = tempunpack!(A)
117
repack! = x[end]
128
values = x[end - 1]
139
indices = x[begin:end-2]
14-
newvalues = unsafe_wrap(Array, _sizedjlmalloc(length(values), eltype(M)), size(values))
15-
copyto!(newvalues, values)
10+
newvalues = unsafe_wrap(Array, _sizedjlmalloc(length(values), storedeltype(C)), size(values))
11+
map!(F, newvalues, values)
1612
newindices = _copytoraw.(indices)
1713
repack!()
14+
unsafepack!(C, newindices..., newvalues, false; decrementindices = false, order = storageorder(A))
15+
return C
16+
end
17+
18+
function Base.convert(::Type{M}, A::N; fill::F = getfill(A)) where {F, M<:AbstractGBArray, N<:AbstractGBArray}
19+
!(F <: Union{Nothing, Missing}) && (fill = convert(storedeltype(M), fill))
1820
B = M(size(A, 1), size(A, 2); fill)
19-
unsafepack!(B, newindices..., newvalues, false; decrementindices = false, order = storageorder(A))
21+
applyjl!(storedeltype(B), B, A)
22+
end
23+
function Base.convert(::Type{M}, A::N; fill::F = getfill(A)) where {F, M<:AbstractGBVector, N<:AbstractGBVector}
24+
!(F <: Union{Nothing, Missing}) && (fill = convert(eltype(M), fill))
25+
B = M(size(A, 1); fill)
26+
applyjl!(storedeltype(B), B, A)
2027
end
2128

2229
Base.convert(::Type{M}, A::M; fill = nothing) where {M<:AbstractGBArray} = A
2330

2431
function LinearAlgebra.copy_oftype(A::GBArrayOrTranspose, ::Type{T}) where T
2532
order = storageorder(A)
2633
C = similar(A, T, size(A))
27-
x = tempunpack!(A)
28-
repack! = x[end]
29-
values = x[end - 1]
30-
indices = x[begin:end-2]
31-
newvalues = unsafe_wrap(Array, _sizedjlmalloc(length(values), T), size(values))
32-
copyto!(newvalues, values)
33-
newindices = _copytoraw.(indices)
34-
repack!()
35-
unsafepack!(C, newindices..., newvalues, false; order, decrementindices = false)
34+
applyjl!(T, C, A)
3635
end
3736
# TODO: Implement this?
3837
Base.convert(::Type{M}, ::AbstractGBArray; fill = nothing) where {M<:AbstractGBShallowArray} =

src/operations/map.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ function apply!(
1515
return C
1616
end
1717

18+
function apply!(
19+
op::DataType, C::GBVecOrMat, A::GBArrayOrTranspose;
20+
mask = nothing, accum = nothing, desc = nothing
21+
)
22+
(mask !== nothing || accum !== nothing || desc !== nothing) &&
23+
throw(ArgumentError("Cannot apply! a DataType with a mask, accum, and desc."))
24+
return applyjl!(op, C, A)
25+
end
26+
1827
function apply!(
1928
op, A::GBArrayOrTranspose{T};
2029
mask = nothing, accum = nothing, desc = nothing

src/unpack.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ function _unpackhypermatrix!(
292292
isiso == C_NULL && (isiso = false)
293293
nvec = nvec[]
294294
colptr = unsafe_wrap(Array, Ptr{Int64}(colptr[]), nvec + 1)
295-
colidx = unsafe_wrap(Array, Ptr{Int64}(colidx), nvec)
296-
rowidx = unsafe_wrap(Array, Ptr{Int64}(rowidx), nnonzeros)
295+
colidx = unsafe_wrap(Array, Ptr{Int64}(colidx[]), nvec)
296+
rowidx = unsafe_wrap(Array, Ptr{Int64}(rowidx[]), nnonzeros)
297297
nstored = isiso[] ? 1 : nnonzeros
298298
vals = unsafe_wrap(Array, Ptr{T}(values[]), nstored)
299299

@@ -524,4 +524,4 @@ function tempunpack!(A::AbstractGBArray, incrementindices = false)
524524
end
525525

526526
# TODO: BITMAP && HYPER
527-
# TODO: A reunsafepack! api?
527+
# TODO: A reunsafepack! api?

test/issues.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,12 @@
6767
x = GBMatrix([1,2], [2, 3], [1,2], fill=0)
6868
@test Array(x) .* [1,2] x .* [1,2]
6969
end
70-
end
70+
@testset "#97" begin
71+
x = GBMatrix([1,2], [2, 3], [1,2], fill=0)
72+
@test SuiteSparseGraphBLAS.storedeltype(Float32.(x)) === Float32
73+
@test SuiteSparseGraphBLAS.storedeltype(map(Float32, x)) === Float32
74+
@test SuiteSparseGraphBLAS.storedeltype(convert(GBMatrix{Float32}, x)) === Float32
75+
@test SuiteSparseGraphBLAS.storedeltype(convert(GBMatrix{Float32, Float32}, x)) === Float32
76+
@test SuiteSparseGraphBLAS.storedeltype(Float32.(x .> 0)) === Float32
77+
end
78+
end

0 commit comments

Comments
 (0)