diff --git a/Project.toml b/Project.toml index fc6a88d..8ea4c0b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ArrayLayouts" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" authors = ["Sheehan Olver "] -version = "1.12.1" +version = "1.12.2" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/src/ArrayLayouts.jl b/src/ArrayLayouts.jl index 6f3d0bf..e928c3a 100644 --- a/src/ArrayLayouts.jl +++ b/src/ArrayLayouts.jl @@ -13,7 +13,7 @@ import Base: axes, size, length, eltype, ndims, first, last, diff, isempty, unio using Base.Broadcast: Broadcasted -import Base.Broadcast: BroadcastStyle, broadcastable, instantiate, materialize, materialize! +import Base.Broadcast: BroadcastStyle, broadcastable, instantiate, materialize, materialize!, result_style, DefaultArrayStyle using LinearAlgebra: AbstractQ, AbstractTriangular, AdjOrTrans, AdjointAbsVec, HermOrSym, HessenbergQ, QRCompactWYQ, QRPackedQ, RealHermSymComplexHerm, TransposeAbsVec, _apply_ipiv_rows!, checknonsingular, @@ -430,4 +430,9 @@ Base.typed_hcat(::Type{T}, A::LayoutVecOrMats, B::LayoutVecOrMats, C::AbstractVe Base.typed_vcat(::Type{T}, A::AbstractVecOrMat, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_vcat(T, A, B, C...) Base.typed_hcat(::Type{T}, A::AbstractVecOrMat, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_hcat(T, A, B, C...) +### +# reshapedarray for layoutarrays +### +BroadcastStyle(::Type{<:ReshapedArray{<:Any, N, P}}) where {N, P<:LayoutArray} = result_style(DefaultArrayStyle{N}(), BroadcastStyle(P)) + end # module diff --git a/test/test_layoutarray.jl b/test/test_layoutarray.jl index b044606..7f1a4c9 100644 --- a/test/test_layoutarray.jl +++ b/test/test_layoutarray.jl @@ -4,6 +4,7 @@ using ArrayLayouts, LinearAlgebra, FillArrays, Test, SparseArrays, Random using ArrayLayouts: sub_materialize, ColumnNorm, RowMaximum, CRowMaximum, @_layoutlmul, Mul import ArrayLayouts: triangulardata, MemoryLayout import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal +import Base.Broadcast: BroadcastStyle, AbstractArrayStyle struct MyMatrix{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T} A::M @@ -38,6 +39,21 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::MyVector{T}) where T = Base.unsafe_conver MemoryLayout(::Type{MyVector{T,V}}) where {T,V} = MemoryLayout(V) Base.copy(A::MyVector) = MyVector(copy(A.A)) +# These structs are separate since we would need to otherwise implement more Broadcast machinery for MyBroadcastStyle to get the tests below to not error +struct MyBroadcastStyle{N} <: AbstractArrayStyle{N} end +MyBroadcastStyle(::Val{N}) where N = MyBroadcastStyle{N}() +MyBroadcastStyle{M}(::Val{N}) where {N,M} = MyBroadcastStyle{N}() +struct MyMatrix2{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T} + A::M +end +Base.size(A::MyMatrix2) = size(A.A) +BroadcastStyle(::Type{<:MyMatrix2{T}}) where {T} = MyBroadcastStyle{2}() +struct MyVector2{T,V<:AbstractVector{T}} <: LayoutVector{T} + A::V +end +Base.size(A::MyVector2) = size(A.A) +BroadcastStyle(::Type{<:MyVector2{T}}) where {T} = MyBroadcastStyle{1}() + # These need to test dispatch reduces to ArrayLayouts.mul, etc. @testset "LayoutArray" begin @testset "LayoutVector" begin @@ -69,6 +85,9 @@ Base.copy(A::MyVector) = MyVector(copy(A.A)) s = SparseVector(3, [1], [2]) @test a's == s'a == dot(a,s) == dot(s,a) == dot(s,a.A) + + a = MyVector2(a.A) + @test BroadcastStyle(typeof(reshape(a, (1, 3)))) == MyBroadcastStyle{2}() end @testset "LayoutMatrix" begin @@ -279,6 +298,12 @@ Base.copy(A::MyVector) = MyVector(copy(A.A)) D = Diagonal(MyVector(randn(5))) @test permutedims(D) ≡ D end + + @testset "BroadcastStyle" begin + A = MyMatrix2(randn(3,6)) + @test BroadcastStyle(typeof(reshape(A, (9, 2)))) == MyBroadcastStyle{2}() + @test BroadcastStyle(typeof(vec(A))) == MyBroadcastStyle{2}() + end end @testset "l/rmul!" begin