Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ArrayLayouts"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
authors = ["Sheehan Olver <[email protected]>"]
version = "1.12.1"
version = "1.12.2"

[deps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand Down
7 changes: 6 additions & 1 deletion src/ArrayLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import Base: axes, size, length, eltype, ndims, first, last, diff, isempty, unio

using Base.Broadcast: Broadcasted

import Base.Broadcast: BroadcastStyle, broadcastable, instantiate, materialize, materialize!
import Base.Broadcast: BroadcastStyle, broadcastable, instantiate, materialize, materialize!, result_style, DefaultArrayStyle

using LinearAlgebra: AbstractQ, AbstractTriangular, AdjOrTrans, AdjointAbsVec, HermOrSym, HessenbergQ, QRCompactWYQ,
QRPackedQ, RealHermSymComplexHerm, TransposeAbsVec, _apply_ipiv_rows!, checknonsingular,
Expand Down Expand Up @@ -430,4 +430,9 @@ Base.typed_hcat(::Type{T}, A::LayoutVecOrMats, B::LayoutVecOrMats, C::AbstractVe
Base.typed_vcat(::Type{T}, A::AbstractVecOrMat, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_vcat(T, A, B, C...)
Base.typed_hcat(::Type{T}, A::AbstractVecOrMat, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_hcat(T, A, B, C...)

###
# reshapedarray for layoutarrays
###
BroadcastStyle(::Type{<:ReshapedArray{<:Any, N, P}}) where {N, P<:LayoutArray} = result_style(DefaultArrayStyle{N}(), BroadcastStyle(P))

end # module
25 changes: 25 additions & 0 deletions test/test_layoutarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ArrayLayouts, LinearAlgebra, FillArrays, Test, SparseArrays, Random
using ArrayLayouts: sub_materialize, ColumnNorm, RowMaximum, CRowMaximum, @_layoutlmul, Mul
import ArrayLayouts: triangulardata, MemoryLayout
import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle

struct MyMatrix{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T}
A::M
Expand Down Expand Up @@ -38,6 +39,21 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::MyVector{T}) where T = Base.unsafe_conver
MemoryLayout(::Type{MyVector{T,V}}) where {T,V} = MemoryLayout(V)
Base.copy(A::MyVector) = MyVector(copy(A.A))

# These structs are separate since we would need to otherwise implement more Broadcast machinery for MyBroadcastStyle to get the tests below to not error
struct MyBroadcastStyle{N} <: AbstractArrayStyle{N} end
MyBroadcastStyle(::Val{N}) where N = MyBroadcastStyle{N}()
MyBroadcastStyle{M}(::Val{N}) where {N,M} = MyBroadcastStyle{N}()
struct MyMatrix2{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T}
A::M
end
Base.size(A::MyMatrix2) = size(A.A)
BroadcastStyle(::Type{<:MyMatrix2{T}}) where {T} = MyBroadcastStyle{2}()
struct MyVector2{T,V<:AbstractVector{T}} <: LayoutVector{T}
A::V
end
Base.size(A::MyVector2) = size(A.A)
BroadcastStyle(::Type{<:MyVector2{T}}) where {T} = MyBroadcastStyle{1}()

# These need to test dispatch reduces to ArrayLayouts.mul, etc.
@testset "LayoutArray" begin
@testset "LayoutVector" begin
Expand Down Expand Up @@ -69,6 +85,9 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))

s = SparseVector(3, [1], [2])
@test a's == s'a == dot(a,s) == dot(s,a) == dot(s,a.A)

a = MyVector2(a.A)
@test BroadcastStyle(typeof(reshape(a, (1, 3)))) == MyBroadcastStyle{2}()
end

@testset "LayoutMatrix" begin
Expand Down Expand Up @@ -279,6 +298,12 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
D = Diagonal(MyVector(randn(5)))
@test permutedims(D) ≡ D
end

@testset "BroadcastStyle" begin
A = MyMatrix2(randn(3,6))
@test BroadcastStyle(typeof(reshape(A, (9, 2)))) == MyBroadcastStyle{2}()
@test BroadcastStyle(typeof(vec(A))) == MyBroadcastStyle{2}()
end
end

@testset "l/rmul!" begin
Expand Down
Loading