Skip to content

Commit 8c6a370

Browse files
author
Dani Pinyol
committed
Fix keytype and zero
1 parent 485fd4b commit 8c6a370

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

src/SparseArrays.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
2222

2323
import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex,
2424
conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin,
25-
float, getindex, imag, inv, kron, kron!, length, map, maximum, minimum, permute!, real,
25+
float, getindex, imag, inv, keytype, kron, kron!, length, map, maximum, minimum, permute!, real,
2626
rot180, rotl90, rotr90, setindex!, show, similar, size, sum, transpose,
2727
vcat, hcat, hvcat, cat, vec, reverse, reverse!
2828

@@ -84,7 +84,8 @@ if Base.USE_GPL_LIBS
8484
include("solvers/spqr.jl")
8585
end
8686

87-
zero(a::AbstractSparseArray) = spzeros(eltype(a), size(a)...)
87+
keytype(::Type{A}) where {Tv, Ti, A<:AbstractSparseArray{Tv,Ti}} = Ti
88+
zero(a::AbstractSparseArray) = spzeros(eltype(a), keytype(a), size(a)...)
8889

8990
LinearAlgebra.diagzero(D::Diagonal{<:AbstractSparseMatrix{T}},i,j) where {T} =
9091
spzeros(T, size(D.diag[i], 1), size(D.diag[j], 2))

test/sparsevector.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include("forbidproperties.jl")
1212
### Data
1313

1414
spv_x1 = SparseVector(8, [2, 5, 6], [1.25, -0.75, 3.5])
15+
spv_x1_32 = SparseVector(8, Int32[2, 5, 6], Float32[1.25, -0.75, 3.5])
1516

1617
@test isa(spv_x1, SparseVector{Float64,Int})
1718

@@ -42,6 +43,14 @@ x1_full[SparseArrays.nonzeroinds(spv_x1)] = nonzeros(spv_x1)
4243
@test @inferred size(y) == (@inferred(length(y))::Int8,)
4344
end
4445

46+
@testset "Non default index type" begin
47+
x = spv_x1_32
48+
for func in [identity, copy, empty, similar, zero]
49+
@test eltype(func(spv_x1_32)) == Float32
50+
@test keytype(func(spv_x1_32)) == Int32
51+
end
52+
end
53+
4554
@testset "isstored" begin
4655
x = spv_x1
4756
stored_inds = [2, 5, 6]

0 commit comments

Comments
 (0)