Skip to content

Commit dfe4fd8

Browse files
authored
Catch more cases of sparse matmul involving adjoints (#60)
1 parent 5e42270 commit dfe4fd8

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.8"
4+
version = "0.5.9"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/abstractsparsearray.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,43 @@ function sparserand!(
152152
A[I] = v
153153
end
154154
end
155+
156+
# Catch some cases that aren't getting caught by the current
157+
# DerivableInterfaces.jl logic.
158+
# TODO: Make this more systematic once DerivableInterfaces.jl
159+
# is rewritten.
160+
using ArrayLayouts: ArrayLayouts, MemoryLayout
161+
using LinearAlgebra: LinearAlgebra, Adjoint
162+
function ArrayLayouts.MemoryLayout(::Type{Transpose{T,P}}) where {T,P<:AbstractSparseMatrix}
163+
return MemoryLayout(P)
164+
end
165+
function ArrayLayouts.MemoryLayout(::Type{Adjoint{T,P}}) where {T,P<:AbstractSparseMatrix}
166+
return MemoryLayout(P)
167+
end
168+
function LinearAlgebra.mul!(
169+
dest::AbstractMatrix,
170+
A::Adjoint{<:Any,<:AbstractSparseMatrix},
171+
B::AbstractSparseMatrix,
172+
α::Number,
173+
β::Number,
174+
)
175+
return ArrayLayouts.mul!(dest, A, B, α, β)
176+
end
177+
function LinearAlgebra.mul!(
178+
dest::AbstractMatrix,
179+
A::AbstractSparseMatrix,
180+
B::Adjoint{<:Any,<:AbstractSparseMatrix},
181+
α::Number,
182+
β::Number,
183+
)
184+
return ArrayLayouts.mul!(dest, A, B, α, β)
185+
end
186+
function LinearAlgebra.mul!(
187+
dest::AbstractMatrix,
188+
A::Adjoint{<:Any,<:AbstractSparseMatrix},
189+
B::Adjoint{<:Any,<:AbstractSparseMatrix},
190+
α::Number,
191+
β::Number,
192+
)
193+
return ArrayLayouts.mul!(dest, A, B, α, β)
194+
end

test/test_linalg.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,22 @@ const rng = StableRNG(123)
1616
A = sparserand(rng, T, szA; density)
1717
B = sparserand(rng, T, szB; density)
1818

19-
check1 = mul!(Array(C), Array(A), Array(B))
20-
@test mul!(copy(C), A, B) check1
19+
check = mul!(Array(C), Array(A), Array(B))
20+
@test mul!(copy(C), A, B) check
21+
22+
check = mul!(Array(C), Array(A)', Array(B))
23+
@test mul!(copy(C), A', B) check
24+
25+
check = mul!(Array(C), Array(A), Array(B)')
26+
@test mul!(copy(C), A, B') check
27+
28+
check = mul!(Array(C), Array(A)', Array(B)')
29+
@test mul!(copy(C), A', B') check
2130

2231
α = rand(rng, T)
2332
β = rand(rng, T)
24-
check2 = mul!(Array(C), Array(A), Array(B), α, β)
25-
@test mul!(copy(C), A, B, α, β) check2
33+
check = mul!(Array(C), Array(A), Array(B), α, β)
34+
@test mul!(copy(C), A, B, α, β) check
2635
end
2736

2837
# test empty matrix

0 commit comments

Comments
 (0)