Skip to content

Commit 3ea20ae

Browse files
committed
Add method
1 parent 1e4fcc1 commit 3ea20ae

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "1.12.1"
4+
version = "1.12.2"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import Base: axes, size, length, eltype, ndims, first, last, diff, isempty, unio
1313

1414
using Base.Broadcast: Broadcasted
1515

16-
import Base.Broadcast: BroadcastStyle, broadcastable, instantiate, materialize, materialize!
16+
import Base.Broadcast: BroadcastStyle, broadcastable, instantiate, materialize, materialize!, result_style, DefaultArrayStyle
1717

1818
using LinearAlgebra: AbstractQ, AbstractTriangular, AdjOrTrans, AdjointAbsVec, HermOrSym, HessenbergQ, QRCompactWYQ,
1919
QRPackedQ, RealHermSymComplexHerm, TransposeAbsVec, _apply_ipiv_rows!, checknonsingular,
@@ -430,4 +430,9 @@ Base.typed_hcat(::Type{T}, A::LayoutVecOrMats, B::LayoutVecOrMats, C::AbstractVe
430430
Base.typed_vcat(::Type{T}, A::AbstractVecOrMat, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_vcat(T, A, B, C...)
431431
Base.typed_hcat(::Type{T}, A::AbstractVecOrMat, B::LayoutVecOrMats, C::AbstractVecOrMat...) where T = typed_hcat(T, A, B, C...)
432432

433+
###
434+
# reshapedarray for layoutarrays
435+
###
436+
BroadcastStyle(::Type{<:ReshapedArray{<:Any, N, P}}) where {N, P<:LayoutArray} = result_style(DefaultArrayStyle{N}(), BroadcastStyle(P))
437+
433438
end # module

test/test_layoutarray.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ArrayLayouts, LinearAlgebra, FillArrays, Test, SparseArrays, Random
44
using ArrayLayouts: sub_materialize, ColumnNorm, RowMaximum, CRowMaximum, @_layoutlmul, Mul
55
import ArrayLayouts: triangulardata, MemoryLayout
66
import LinearAlgebra: Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal
7+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle
78

89
struct MyMatrix{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T}
910
A::M
@@ -38,6 +39,21 @@ Base.unsafe_convert(::Type{Ptr{T}}, A::MyVector{T}) where T = Base.unsafe_conver
3839
MemoryLayout(::Type{MyVector{T,V}}) where {T,V} = MemoryLayout(V)
3940
Base.copy(A::MyVector) = MyVector(copy(A.A))
4041

42+
# These structs are separate since we would need to otherwise implement more Broadcast machinery for MyBroadcastStyle to get the tests below to not error
43+
struct MyBroadcastStyle{N} <: AbstractArrayStyle{N} end
44+
MyBroadcastStyle(::Val{N}) where N = MyBroadcastStyle{N}()
45+
MyBroadcastStyle{M}(::Val{N}) where {N,M} = MyBroadcastStyle{N}()
46+
struct MyMatrix2{T,M<:AbstractMatrix{T}} <: LayoutMatrix{T}
47+
A::M
48+
end
49+
Base.size(A::MyMatrix2) = size(A.A)
50+
BroadcastStyle(::Type{<:MyMatrix2{T}}) where {T} = MyBroadcastStyle{2}()
51+
struct MyVector2{T,V<:AbstractVector{T}} <: LayoutVector{T}
52+
A::V
53+
end
54+
Base.size(A::MyVector2) = size(A.A)
55+
BroadcastStyle(::Type{<:MyVector2{T}}) where {T} = MyBroadcastStyle{1}()
56+
4157
# These need to test dispatch reduces to ArrayLayouts.mul, etc.
4258
@testset "LayoutArray" begin
4359
@testset "LayoutVector" begin
@@ -69,6 +85,9 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
6985

7086
s = SparseVector(3, [1], [2])
7187
@test a's == s'a == dot(a,s) == dot(s,a) == dot(s,a.A)
88+
89+
a = MyVector2(a.A)
90+
@test BroadcastStyle(typeof(reshape(a, (1, 3)))) == MyBroadcastStyle{2}()
7291
end
7392

7493
@testset "LayoutMatrix" begin
@@ -279,6 +298,12 @@ Base.copy(A::MyVector) = MyVector(copy(A.A))
279298
D = Diagonal(MyVector(randn(5)))
280299
@test permutedims(D) D
281300
end
301+
302+
@testset "BroadcastStyle" begin
303+
A = MyMatrix2(randn(3,6))
304+
@test BroadcastStyle(typeof(reshape(A, (9, 2)))) == MyBroadcastStyle{2}()
305+
@test BroadcastStyle(typeof(vec(A))) == MyBroadcastStyle{2}()
306+
end
282307
end
283308

284309
@testset "l/rmul!" begin

0 commit comments

Comments
 (0)