Skip to content

Commit d938744

Browse files
dpinolDani Pinyol
andauthored
Fix "map does not respect sources index type" (#350) (#351)
Co-authored-by: Dani Pinyol <[email protected]>
1 parent 38604c3 commit d938744

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/sparsevector.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,7 @@ function _binarymap(f::Function,
13801380
mode::Int) where {Tx,Ty}
13811381
0 <= mode <= 2 || throw(ArgumentError("Incorrect mode $mode."))
13821382
R = Base.Broadcast.combine_eltypes(f, (x, y))
1383+
I = promote_type(eltype(nonzeroinds(x)), eltype(nonzeroinds(y)))
13831384
n = length(x)
13841385
length(y) == n || throw(DimensionMismatch())
13851386

@@ -1391,7 +1392,7 @@ function _binarymap(f::Function,
13911392
my = length(ynzind)
13921393
cap = (mode == 0 ? min(mx, my) : mx + my)::Int
13931394

1394-
rind = Vector{Int}(undef, cap)
1395+
rind = Vector{I}(undef, cap)
13951396
rval = Vector{R}(undef, cap)
13961397
ir = 0
13971398
ir = (

test/sparsevector.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,12 @@ end
16011601
@test length(nonzeros(simA)) == 0
16021602
end
16031603

1604+
@testset "map preserves index types" begin
1605+
v1 = spzeros(Float32, Int16, 10)
1606+
v2 = spzeros(Float32, Int32, 10)
1607+
@test eltype(typeof(SparseArrays.nonzeroinds(map(max, v1, v2)))) == Int32
1608+
end
1609+
16041610
@testset "Fast operations on full column views" begin
16051611
n = 1000
16061612
A = sprandn(n, n, 0.01)

0 commit comments

Comments
 (0)