Skip to content

Commit f9fe08b

Browse files
authored
Fixes for adjtrans muladd, and map(copy, ::Diagonal) (#52)
* Fix MulAdd([1,2], [[1,2]',[2,3]') * Update muladd.jl * map(copy, ::Diagonal), using in BlockArrays * Update runtests.jl
1 parent 5111eb1 commit f9fe08b

File tree

5 files changed

+31
-3
lines changed

5 files changed

+31
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.5.1"
4+
version = "0.5.2"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/ArrayLayouts.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ if VERSION ≥ v"1.5"
189189
copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
190190
end
191191

192+
# avoid bad copy in Base
193+
Base.map(::typeof(copy), D::Diagonal{<:LayoutArray}) = Diagonal(map(copy, D.diag))
194+
192195
zero!(A::AbstractArray{T}) where T = fill!(A,zero(T))
193196
function zero!(A::AbstractArray{<:AbstractArray})
194197
for a in A

src/muladd.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,17 @@ copy(M::MulAdd{<:AbstractFillLayout,<:AbstractFillLayout,<:AbstractFillLayout})
399399
# DualLayout
400400
###
401401

402+
transtype(::Adjoint) = adjoint
403+
transtype(::Transpose) = transpose
404+
402405
function similar(M::MulAdd{<:DualLayout,<:Any,ZerosLayout}, ::Type{T}, (x,y)) where T
403-
@assert x Base.OneTo(1)
404-
similar(M.A', T, y)'
406+
@assert length(x) == 1
407+
trans = transtype(M.A)
408+
trans(similar(trans(M.A), T, y))
409+
end
410+
411+
function similar(M::MulAdd{<:Any,<:DualLayout,ZerosLayout}, ::Type{T}, (x,y)) where T
412+
@assert length(x) == 1
413+
trans = transtype(M.B)
414+
trans(similar(trans(M.B), T, y))
405415
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
152152
@test x' * Symmetric(MyMatrix(A)) x'Symmetric(A)
153153
@test transpose(x) * Symmetric(MyMatrix(A)) transpose(x)Symmetric(A)
154154
end
155+
156+
@testset "map(copy, ::Diagonal)" begin
157+
# this is needed in BlockArrays
158+
D = Diagonal([MyMatrix(randn(2,2)), MyMatrix(randn(2,2))])
159+
@test map(copy, D) == D
160+
end
155161
end
156162

157163
@testset "l/rmul!" begin

test/test_muladd.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,4 +660,13 @@ Random.seed!(0)
660660
@test ArrayLayouts.dot(a,b) == mul(a',b) == dot(a,b)
661661
@test eltype(Dot(a,1:5)) == Float64
662662
end
663+
664+
@testset "adjtrans muladd" begin
665+
A,B = [1 2], [[1,2]',[3,4]']
666+
= [transpose([1,2]), transpose([3,4])]
667+
@test copy(MulAdd(A,B)) == A*B
668+
@test eltype(MulAdd(A,B)) == eltype(B)
669+
@test copy(MulAdd(A,B̃)) == A*
670+
@test eltype(MulAdd(A,B̃)) == eltype(B̃)
671+
end
663672
end

0 commit comments

Comments
 (0)