Skip to content

Commit f1d9f43

Browse files
authored
Support ExpansionLayout in ContinuumArrays.jl (#80)
* Support ExpansionLayout in ContinuumArrays.jl * Update quasifill.jl * InclusionLayout, simplify layout_broadcasted * PolynomialLayout * Increase coverage * fix tests
1 parent f0d2a5e commit f1d9f43

11 files changed

+109
-23
lines changed

src/QuasiArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import Base: ones, zeros, one, zero, fill
3434

3535
import Base.Broadcast: materialize, materialize!, BroadcastStyle, AbstractArrayStyle, Style, broadcasted, Broadcasted, Unknown,
3636
newindex, broadcastable, preprocess, _eachindex, _broadcast_getindex, broadcast_shape,
37-
DefaultArrayStyle, axistype, throwdm, instantiate, combine_eltypes, eltypes
37+
DefaultArrayStyle, axistype, throwdm, instantiate, combine_eltypes, eltypes, combine_styles
3838

3939
import LinearAlgebra: transpose, adjoint, checkeltype_adjoint, checkeltype_transpose, Diagonal,
4040
AbstractTriangular, pinv, inv, promote_leaf_eltypes, power_by_squaring,

src/abstractquasiarray.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ function isequal(A::AbstractQuasiArray, B::AbstractQuasiArray)
504504
return true
505505
end
506506

507-
function (==)(A::AbstractQuasiArray, B::AbstractQuasiArray)
507+
function _equals(_, _, A, B)
508508
if axes(A) != axes(B)
509509
return false
510510
end
@@ -521,8 +521,13 @@ function (==)(A::AbstractQuasiArray, B::AbstractQuasiArray)
521521
end
522522

523523

524+
(==)(A::AbstractQuasiArray, B::AbstractQuasiArray) = _equals(MemoryLayout(A), MemoryLayout(B), A, B)
525+
524526
##
525527
# show
526528
##
527529

528-
show(io::IO, A::AbstractQuasiArray) = summary(io, A)
530+
show(io::IO, A::AbstractQuasiArray) = summary(io, A)
531+
532+
struct QuasiArrayLayout <: MemoryLayout end
533+
MemoryLayout(::Type{<:AbstractQuasiArray}) = QuasiArrayLayout()

src/indices.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ Inclusion(domain) = Inclusion{eltype(domain)}(domain)
142142
Inclusion(S::Inclusion) = S
143143
Inclusion(S::Slice) = Inclusion(S.indices)
144144

145+
struct PolynomialLayout <: MemoryLayout end
146+
147+
MemoryLayout(::Type{<:Inclusion}) = PolynomialLayout()
148+
145149
convert(::Type{Inclusion}, d::Inclusion) = d
146150
convert(::Type{Inclusion{T}}, d::Inclusion) where T = Inclusion{T}(d)
147151
convert(::Type{AbstractVector}, d::Inclusion{<:Any,<:AbstractVector}) =

src/lazyquasiarrays.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ subbroadcaststyle(::LazyQuasiArrayStyle{2}, ::Type{<:Tuple{Number,Number}}) = De
102102
subbroadcaststyle(::LazyQuasiArrayStyle{2}, ::Type{<:Tuple{Number,JR}}) where JR = Base.BroadcastStyle(JR)
103103
subbroadcaststyle(::LazyQuasiArrayStyle{2}, ::Type{<:Tuple{KR,Number}}) where KR = Base.BroadcastStyle(KR)
104104

105+
layout_broadcasted(_, f, args...) = Broadcasted{typeof(combine_styles(args...))}(f, args)
106+
107+
broadcasted(S::LazyQuasiArrayStyle, f, args...) = layout_broadcasted(map(MemoryLayout,args), f, args...)
108+
105109

106110
struct BroadcastQuasiArray{T, N, F, Args} <: LazyQuasiArray{T, N}
107111
f::F
@@ -137,7 +141,7 @@ broadcasted(A::BroadcastQuasiArray) = instantiate(broadcasted(A.f, A.args...))
137141
axes(A::BroadcastQuasiArray) = axes(broadcasted(A))
138142
size(A::BroadcastQuasiArray) = map(length, axes(A))
139143

140-
function ==(A::BroadcastQuasiArray, B::BroadcastQuasiArray)
144+
function _equals(::BroadcastLayout, ::BroadcastLayout, A, B)
141145
A.f == B.f && all(A.args .== B.args) && return true
142146
error("Not implemented")
143147
end
@@ -170,6 +174,9 @@ call(b::BroadcastLayout, a::SubQuasiArray) = call(b, parent(a))
170174
summary(io::IO, A::BroadcastQuasiArray) = _broadcastarray_summary(io, A)
171175
summary(io::IO, A::ApplyQuasiArray) = _applyarray_summary(io, A)
172176

177+
_mul_summary(_, io, A) = _applyarray_summary(io, A)
178+
summary(io::IO, A::ApplyQuasiArray{<:Any,N,typeof(*)}) where N = _mul_summary(MemoryLayout(A), io, A)
179+
173180
for op in (:+, :-, :*, :\, :/)
174181
@eval begin
175182
function summary(io::IO, A::BroadcastQuasiArray{<:Any,N,typeof($op)}) where N

src/matmul.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,6 @@ function copyto!(dest::MulQuasiArray, src::MulQuasiArray)
8686
dest
8787
end
8888

89-
90-
struct QuasiArrayLayout <: MemoryLayout end
91-
MemoryLayout(::Type{<:AbstractQuasiArray}) = QuasiArrayLayout()
9289
Array(M::Mul{<:Any,<:Any,<:Any,<:AbstractQuasiMatrix}) = eltype(M)[M[k,j] for k in axes(M)[1], j in axes(M)[2]]
9390
Array(M::Mul{<:Any,<:Any,<:Any,<:AbstractQuasiVector}) = eltype(M)[M[k] for k in axes(M)[1]]
9491
_quasi_mul(M, _) = QuasiArray(M)
@@ -132,8 +129,12 @@ _rdiv_scal_reduce(x::Number, Z) = (Z / x,)
132129
_rdiv_scal_reduce(x::Number, Z::AbstractArray, Y...) = (Y..., Z/x)
133130
_rdiv_scal_reduce(x::Number, Z, Y...) = (_rdiv_scal_reduce(x, Y...)..., Z)
134131

135-
*(x::Number, A::MulQuasiArray) = ApplyQuasiArray(*, _lmul_scal_reduce(x, arguments(A)...)...)
136-
*(A::MulQuasiArray, x::Number) = ApplyQuasiArray(*, _rmul_scal_reduce(x, reverse(arguments(A))...)...)
132+
broadcasted(::LazyQuasiArrayStyle{N}, ::typeof(*), x::Number, A::MulQuasiArray{<:Any,N}) where N =
133+
ApplyQuasiArray(*, _lmul_scal_reduce(x, arguments(A)...)...)
134+
broadcasted(::LazyQuasiArrayStyle{N}, ::typeof(*), A::MulQuasiArray{<:Any,N}, x::Number) where N =
135+
ApplyQuasiArray(*, _rmul_scal_reduce(x, reverse(arguments(A))...)...)
137136

138-
\(x::Number, A::MulQuasiArray) = ApplyQuasiArray(*, _ldiv_scal_reduce(x, arguments(A)...)...)
139-
/(A::MulQuasiArray, x::Number) = ApplyQuasiArray(*, _rdiv_scal_reduce(x, reverse(arguments(A))...)...)
137+
broadcasted(::LazyQuasiArrayStyle{N}, ::typeof(\), x::Number, A::MulQuasiArray{<:Any,N}) where N =
138+
ApplyQuasiArray(*, _ldiv_scal_reduce(x, arguments(A)...)...)
139+
broadcasted(::LazyQuasiArrayStyle{N}, ::typeof(/), A::MulQuasiArray{<:Any,N}, x::Number) where N =
140+
ApplyQuasiArray(*, _ldiv_scal_reduce(x, arguments(A)...)...)

src/quasifill.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ copy(F::QuasiFill) = QuasiFill(F.value, F.axes)
9898
convert(::Type{T}, F::T) where T<:QuasiFill = F
9999

100100

101-
102101
getindex(F::QuasiFill{<:Any,0}) = getindex_value(F)
103102

104103

@@ -191,6 +190,40 @@ for op in (:+, :-, :*, :/, :\)
191190
end
192191

193192

193+
########
194+
# show
195+
#######
196+
197+
function summary(io::IO, F::QuasiOnes)
198+
print(io, "ones(")
199+
summary(io, F.axes[1])
200+
for a in tail(F.axes)
201+
print(io, ", ")
202+
summary(io, a)
203+
end
204+
print(io, ")")
205+
end
206+
207+
function summary(io::IO, F::QuasiZeros)
208+
print(io, "zeros(")
209+
summary(io, F.axes[1])
210+
for a in tail(F.axes)
211+
print(io, ", ")
212+
summary(io, a)
213+
end
214+
print(io, ")")
215+
end
216+
217+
function summary(io::IO, F::QuasiFill)
218+
print(io, "fill($(F.value), ")
219+
summary(io, F.axes[1])
220+
for a in tail(F.axes)
221+
print(io, ", ")
222+
summary(io, a)
223+
end
224+
print(io, ")")
225+
end
226+
194227
#########
195228
# Special matrix types
196229
#########

src/quasireducedim.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,19 +266,37 @@ end
266266
_sum(V::AbstractQuasiArray, dims) = __sum(MemoryLayout(V), V, dims)
267267
_sum(V::AbstractQuasiArray, ::Colon) = __sum(MemoryLayout(V), V, :)
268268

269-
# sum is equivalent to hitting by ones(n) on the left or rifght
270-
function __sum(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiMatrix, d::Int)
271-
a = arguments(LAY, V)
272-
if d == 1
273-
*(sum(first(a); dims=1), tail(a)...)
274-
else
275-
@assert d == 2
276-
*(most(a)..., sum(last(a); dims=2))
269+
_cumsum(A, dims) = __cumsum(MemoryLayout(A), A, dims)
270+
cumsum(A::AbstractQuasiArray; dims::Integer) = _cumsum(A, dims)
271+
cumsum(x::AbstractQuasiVector) = cumsum(x, dims=1)
272+
273+
# sum is equivalent to hitting by ones(n) on the left or right
274+
275+
__cumsum(::QuasiArrayLayout, A, ::Colon) = QuasiArray(cumsum(parent(A)), axes(A))
276+
__cumsum(::QuasiArrayLayout, A, d::Int) = QuasiArray(cumsum(parent(A),dims=d), axes(A))
277+
278+
for Sum in (:sum, :cumsum)
279+
__Sum = Symbol("__", Sum)
280+
@eval function $__Sum(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiMatrix, d::Int)
281+
a = arguments(LAY, V)
282+
if d == 1
283+
*($Sum(first(a); dims=1), tail(a)...)
284+
else
285+
@assert d == 2
286+
*(most(a)..., $Sum(last(a); dims=2))
287+
end
277288
end
278289
end
290+
291+
function __cumsum(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVector, dims)
292+
a = arguments(LAY, V)
293+
apply(*, cumsum(a[1]; dims=dims), tail(a)...)
294+
end
295+
279296
function __sum(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVector, ::Colon)
280297
a = arguments(LAY, V)
281298
first(apply(*, sum(a[1]; dims=1), tail(a)...))
282299
end
283300

284-
__sum(_, A, dims) = _sum(identity, A, dims)
301+
__sum(_, A, dims) = _sum(identity, A, dims)
302+

test/test_matmul.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import QuasiArrays: apply
77
@test rowsupport(A,0.1) == colsupport(A,0.1) == 0.1
88
b = Inclusion(0:0.1:1)
99
Ab = A*b
10-
@test Ab isa QuasiArray
1110
@test Ab[0.1] 0.1^2
1211
@test_throws DimensionMismatch A*Inclusion(1:2)
1312

test/test_quasifill.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,4 +704,10 @@ import QuasiArrays: AbstractQuasiFill
704704
@test_throws DimensionMismatch I+B
705705
@test B*I I*B B
706706
end
707+
708+
@testset "show" begin
709+
@test stringmime("text/plain",ones(Inclusion([1,2,3]))) == "ones(Inclusion([1, 2, 3]))"
710+
@test stringmime("text/plain",zeros(Inclusion([1,2,3]))) == "zeros(Inclusion([1, 2, 3]))"
711+
@test stringmime("text/plain",fill(2,Inclusion([1,2,3]))) == "fill(2, Inclusion([1, 2, 3]))"
712+
end
707713
end

test/test_quasilazy.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using QuasiArrays, LazyArrays, ArrayLayouts, Test
1+
using QuasiArrays, LazyArrays, ArrayLayouts, Base64, Test
22
import QuasiArrays: QuasiLazyLayout, QuasiArrayApplyStyle, LazyQuasiMatrix, LazyQuasiArrayStyle
33
import LazyArrays: MulStyle, ApplyStyle
44

@@ -78,6 +78,11 @@ Base.getindex(A::MyQuasiLazyMatrix, x::Float64, y::Float64) = A.A[x,y]
7878

7979
@test BroadcastQuasiArray(*, x, ApplyQuasiArray(^, A, 2)) * y (x .* A^2) * y
8080
end
81+
82+
@testset "summary" begin
83+
A = ApplyQuasiArray(*, ones(Inclusion([1,2,3]), Inclusion([4,5])), fill(2,Inclusion([4,5])))
84+
@test stringmime("text/plain", A) == "(ones(Inclusion([1, 2, 3]), Inclusion([4, 5]))) * (fill(2, Inclusion([4, 5])))"
85+
end
8186
end
8287
@testset "\\" begin
8388
A = QuasiArray(rand(3,3),(0:0.5:1,0:0.5:1))

0 commit comments

Comments
 (0)