Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,13 @@ function sparse_mul!(
β::Number=false;
(mul!!)=(default_mul!!),
)
storedvalues(a_dest) .*= β
β′ = one(Bool)
for I1 in eachstoredindex(a1)
for I2 in eachstoredindex(a2)
I_dest = mul_indices(I1, I2)
if !isnothing(I_dest)
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β)
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β)
end
end
end
Expand Down
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
38 changes: 38 additions & 0 deletions test/basics/test_linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using SparseArraysBase: SparseArrayDOK
using LinearAlgebra: mul!
using Random: Random
using StableRNGs: StableRNG

const rng = StableRNG(123)

# TODO: add this to main package
function sprand(rng::Random.AbstractRNG, ::Type{T}, sz::Base.Dims; p::Real=0.5) where {T}
A = SparseArrayDOK{T}(undef, sz)
for I in eachindex(A)
if rand(rng) < p
A[I] = rand(rng, T)
end
end
return A
end

@testset "mul!" begin
T = Float64
szA = (2, 2)
szB = (2, 2)
szC = (szA[1], szB[2])

for p in 0.0:0.25:1
C = sprand(rng, T, szC; p)
A = sprand(rng, T, szA; p)
B = sprand(rng, T, szB; p)

check1 = mul!(Array(C), Array(A), Array(B))
@test mul!(copy(C), A, B) ≈ check1

α = rand(rng, T)
β = rand(rng, T)
check2 = mul!(Array(C), Array(A), Array(B), α, β)
@test mul!(copy(C), A, B, α, β) ≈ check2
end
end
Loading