Skip to content

Commit da39287

Browse files
abraunstandreasnoack
authored andcommitted
make SparseMatrixCSC and SparseVector work on non-numerical values (#30580)
* make SparseMatrixCSC and SparseVector work on non-numerical values * remove tests for sparse structures with TV=String as it is no longer supported * define zero(T) and zero(x::T) for user defined types in tests * fix test * Remove recently added test of sparse matrix with string elements
1 parent 85603e1 commit da39287

File tree

6 files changed

+19
-28
lines changed

6 files changed

+19
-28
lines changed

stdlib/SparseArrays/src/sparsematrix.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ function SparseMatrixCSC{Tv,Ti}(M::AbstractMatrix) where {Tv,Ti}
385385
end
386386

387387
function SparseMatrixCSC{Tv,Ti}(M::StridedMatrix) where {Tv,Ti}
388-
nz = count(t -> t != 0, M)
388+
nz = count(!iszero, M)
389389
colptr = zeros(Ti, size(M, 2) + 1)
390390
nzval = Vector{Tv}(undef, nz)
391391
rowval = Vector{Ti}(undef, nz)
@@ -394,7 +394,7 @@ function SparseMatrixCSC{Tv,Ti}(M::StridedMatrix) where {Tv,Ti}
394394
@inbounds for j in 1:size(M, 2)
395395
for i in 1:size(M, 1)
396396
v = M[i, j]
397-
if v != 0
397+
if !iszero(v)
398398
rowval[cnt] = i
399399
nzval[cnt] = v
400400
cnt += 1
@@ -1241,7 +1241,7 @@ Removes stored numerical zeros from `A`, optionally trimming resulting excess sp
12411241
For an out-of-place version, see [`dropzeros`](@ref). For
12421242
algorithmic information, see `fkeep!`.
12431243
"""
1244-
dropzeros!(A::SparseMatrixCSC; trim::Bool = true) = fkeep!(A, (i, j, x) -> x != 0, trim)
1244+
dropzeros!(A::SparseMatrixCSC; trim::Bool = true) = fkeep!(A, (i, j, x) -> !iszero(x), trim)
12451245
"""
12461246
dropzeros(A::SparseMatrixCSC; trim::Bool = true)
12471247
@@ -2330,7 +2330,7 @@ function _setindex_scalar!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integ
23302330
end
23312331
# Column j does not contain entry A[i,j]. If v is nonzero, insert entry A[i,j] = v
23322332
# and return. If to the contrary v is zero, then simply return.
2333-
if v != 0
2333+
if !iszero(v)
23342334
insert!(A.rowval, searchk, i)
23352335
insert!(A.nzval, searchk, v)
23362336
@simd for m in (j + 1):(A.n + 1)
@@ -3184,7 +3184,7 @@ function is_hermsym(A::SparseMatrixCSC, check::Function)
31843184
# We therefore "catch up" here while making sure that
31853185
# the elements are actually zero.
31863186
while row2 < col
3187-
if nzval[offset] != 0
3187+
if !iszero(nzval[offset])
31883188
return false
31893189
end
31903190
offset += 1
@@ -3222,7 +3222,7 @@ function istriu(A::SparseMatrixCSC)
32223222
if rowval[l1-i] <= col
32233223
break
32243224
end
3225-
if nzval[l1-i] != 0
3225+
if !iszero(nzval[l1-i])
32263226
return false
32273227
end
32283228
end
@@ -3241,7 +3241,7 @@ function istril(A::SparseMatrixCSC)
32413241
if rowval[i] >= col
32423242
break
32433243
end
3244-
if nzval[i] != 0
3244+
if !iszero(nzval[i])
32453245
return false
32463246
end
32473247
end

stdlib/SparseArrays/src/sparsevector.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer}
300300
if 1 <= k <= m && nzind[k] == i # i found
301301
nzval[k] = v
302302
else # i not found
303-
if v != 0
303+
if !iszero(v)
304304
insert!(nzind, k, i)
305305
insert!(nzval, k, v)
306306
end
@@ -392,7 +392,7 @@ function _dense2indval!(nzind::Vector{Ti}, nzval::Vector{Tv}, s::AbstractArray{T
392392
c = 0
393393
@inbounds for i = 1:n
394394
v = s[i]
395-
if v != 0
395+
if !iszero(v)
396396
if c >= cap
397397
cap *= 2
398398
resize!(nzind, cap)
@@ -1929,7 +1929,7 @@ Removes stored numerical zeros from `x`, optionally trimming resulting excess sp
19291929
For an out-of-place version, see [`dropzeros`](@ref). For
19301930
algorithmic information, see `fkeep!`.
19311931
"""
1932-
dropzeros!(x::SparseVector; trim::Bool = true) = fkeep!(x, (i, x) -> x != 0, trim)
1932+
dropzeros!(x::SparseVector; trim::Bool = true) = fkeep!(x, (i, x) -> !iszero(x), trim)
19331933

19341934
"""
19351935
dropzeros(x::SparseVector; trim::Bool = true)

stdlib/SparseArrays/test/higherorderfns.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,17 +287,14 @@ end
287287
A, fA = sparse(1.0I, N, N), Matrix(1.0I, N, N)
288288
B, fB = spzeros(1, N), zeros(1, N)
289289
intorfloat_zeropres(xs...) = all(iszero, xs) ? zero(Float64) : Int(1)
290-
stringorfloat_zeropres(xs...) = all(iszero, xs) ? zero(Float64) : "hello"
291290
intorfloat_notzeropres(xs...) = all(iszero, xs) ? Int(1) : zero(Float64)
292-
stringorfloat_notzeropres(xs...) = all(iszero, xs) ? "hello" : zero(Float64)
293-
for fn in (intorfloat_zeropres, intorfloat_notzeropres,
294-
stringorfloat_zeropres, stringorfloat_notzeropres)
291+
for fn in (intorfloat_zeropres, intorfloat_notzeropres)
295292
@test map(fn, A) == sparse(map(fn, fA))
296293
@test broadcast(fn, A) == sparse(broadcast(fn, fA))
297294
@test broadcast(fn, A, B) == sparse(broadcast(fn, fA, fB))
298295
@test broadcast(fn, B, A) == sparse(broadcast(fn, fB, fA))
299296
end
300-
for fn in (intorfloat_zeropres, stringorfloat_zeropres)
297+
for fn in (intorfloat_zeropres,)
301298
@test broadcast(fn, A, B, A) == sparse(broadcast(fn, fA, fB, fA))
302299
end
303300
end

stdlib/SparseArrays/test/sparse.jl

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,15 +1263,16 @@ end
12631263
@test isequal(findmax(A, dims=tup), (rval, rind))
12641264
end
12651265

1266-
A = sparse(["a", "b"])
1267-
@test_throws MethodError findmin(A, dims=1)
1266+
# sparse arrays of types without zero(T) are forbidden
1267+
@test_throws MethodError sparse(["a", "b"])
12681268
end
12691269

12701270
# Support the case when user defined `zero` and `isless` for non-numerical type
12711271
struct CustomType
12721272
x::String
12731273
end
12741274
Base.zero(::Type{CustomType}) = CustomType("")
1275+
Base.zero(x::CustomType) = zero(CustomType)
12751276
Base.isless(x::CustomType, y::CustomType) = isless(x.x, y.x)
12761277
@testset "findmin/findmax for non-numerical type" begin
12771278
A = sparse([CustomType("a"), CustomType("b")])
@@ -2286,17 +2287,6 @@ end
22862287
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
22872288
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
22882289
end
2289-
2290-
w = [ "a" ""; "" "b"]
2291-
w_sp = sparse(w)
2292-
2293-
for i in keys(w)
2294-
@test findnext(!isequal(""), w,i) == findnext(!isequal(""), w_sp,i)
2295-
@test findprev(!isequal(""), w,i) == findprev(!isequal(""), w_sp,i)
2296-
@test findnext(isequal(""), w,i) == findnext(isequal(""), w_sp,i)
2297-
@test findprev(isequal(""), w,i) == findprev(isequal(""), w_sp,i)
2298-
end
2299-
23002290
end
23012291

23022292
# #20711

test/hashing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ Base.hash(x::CustomHashReal, h::UInt) = hash(x.x, h)
159159
Base.:(==)(x::CustomHashReal, y::Number) = x.x == y
160160
Base.:(==)(x::Number, y::CustomHashReal) = x == y.x
161161
Base.zero(::Type{CustomHashReal}) = CustomHashReal(0.0)
162+
Base.zero(x::CustomHashReal) = zero(CustomHashReal)
162163

163164
let a = sparse([CustomHashReal(0), CustomHashReal(3), CustomHashReal(3)])
164165
@test hash(a) == hash(Array(a))

test/show.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,9 @@ end
560560

561561
# issue #12960
562562
mutable struct T12960 end
563+
import Base.zero
564+
Base.zero(::Type{T12960}) = T12960()
565+
Base.zero(x::T12960) = T12960()
563566
let
564567
A = sparse(1.0I, 3, 3)
565568
B = similar(A, T12960)

0 commit comments

Comments
 (0)