Skip to content

Commit a57e030

Browse files
lkdvosmtfishman
andauthored
Bugfix sparse_mul! handling of beta (#18)
This changes `sparse_mul!` to correctly handle scaling of the destination array. In order to achieve this, the scaling is separated out and handled before the actual multiplication algorithm. --------- Co-authored-by: Matt Fishman <[email protected]>
1 parent f8ad6ae commit a57e030

File tree

4 files changed

+47
-2
lines changed

4 files changed

+47
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.7"
4+
version = "0.2.8"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/abstractsparsearrayinterface.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,15 @@ function sparse_mul!(
357357
β::Number=false;
358358
(mul!!)=(default_mul!!),
359359
)
360+
# TODO: Change to: `a_dest .*= β`
361+
# once https://github.com/ITensor/SparseArraysBase.jl/issues/19 is fixed.
362+
storedvalues(a_dest) .*= β
363+
β′ = one(Bool)
360364
for I1 in eachstoredindex(a1)
361365
for I2 in eachstoredindex(a2)
362366
I_dest = mul_indices(I1, I2)
363367
if !isnothing(I_dest)
364-
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β)
368+
a_dest[I_dest] = mul!!(a_dest[I_dest], a1[I1], a2[I2], α, β)
365369
end
366370
end
367371
end

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
55
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
66
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
77
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
810
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
911
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
12+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1013
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1114
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/basics/test_linalg.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using SparseArraysBase: SparseArrayDOK
2+
using LinearAlgebra: mul!
3+
using Random: Random
4+
using StableRNGs: StableRNG
5+
6+
const rng = StableRNG(123)
7+
8+
# TODO: add this to main package
9+
function sprand(rng::Random.AbstractRNG, ::Type{T}, sz::Base.Dims; p::Real=0.5) where {T}
10+
A = SparseArrayDOK{T}(undef, sz)
11+
for I in eachindex(A)
12+
if rand(rng) < p
13+
A[I] = rand(rng, T)
14+
end
15+
end
16+
return A
17+
end
18+
19+
@testset "mul!" begin
20+
T = Float64
21+
szA = (2, 2)
22+
szB = (2, 2)
23+
szC = (szA[1], szB[2])
24+
25+
for p in 0.0:0.25:1
26+
C = sprand(rng, T, szC; p)
27+
A = sprand(rng, T, szA; p)
28+
B = sprand(rng, T, szB; p)
29+
30+
check1 = mul!(Array(C), Array(A), Array(B))
31+
@test mul!(copy(C), A, B) check1
32+
33+
α = rand(rng, T)
34+
β = rand(rng, T)
35+
check2 = mul!(Array(C), Array(A), Array(B), α, β)
36+
@test mul!(copy(C), A, B, α, β) check2
37+
end
38+
end

0 commit comments

Comments
 (0)