Skip to content

Commit 631e150

Browse files
authored
Add mul/colsupport for MulQuasiArray, support overloading dot and sum (#94)
* Add mul/colsupport for MulQuasiArray * multiplication with QuasiArrays and Zeros * Improve show * increase coverage * Fix test * fix tests * sum -> sum_layout
1 parent 6362a48 commit 631e150

15 files changed

+89
-34
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ jobs:
1212
version:
1313
- '1.6'
1414
- '1'
15-
- '^1.9.0-0'
1615
os:
1716
- ubuntu-latest
1817
- macOS-latest

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "QuasiArrays"
22
uuid = "c4ea9172-b204-11e9-377d-29865faadc5c"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.9.8"
4+
version = "0.10"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -12,10 +12,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1313

1414
[compat]
15-
ArrayLayouts = "0.7.6, 0.8, 1"
16-
DomainSets = "0.5, 0.6"
17-
FillArrays = "0.12, 0.13, 1"
18-
LazyArrays = "0.22.2, 1"
15+
ArrayLayouts = "1"
16+
DomainSets = "0.6"
17+
FillArrays = "1"
18+
LazyArrays = "1.2"
1919
StaticArrays = "1"
2020
julia = "1.6"
2121

src/QuasiArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ import LazyArrays: MemoryLayout, UnknownLayout, Mul, ApplyLayout, BroadcastLayou
4949
LdivStyle, InvLayout, PInvLayout, sub_materialize, lazymaterialize,
5050
_mul, rowsupport, DiagonalLayout, adjointlayout, transposelayout, conjlayout,
5151
sublayout, call, LazyArrayStyle, layout_getindex, _broadcast2broadcastarray, _applyarray_summary, _broadcastarray_summary,
52-
_broadcasted_mul, simplifiable, simplify
52+
_broadcasted_mul, simplifiable, simplify, _mul_colsupport, _mul_rowsupport
5353

5454
import Base.IteratorsMD
5555

src/abstractquasiarray.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,5 @@ end
523523
##
524524
# show
525525
##
526-
527-
show(io::IO, A::AbstractQuasiArray) = summary(io, A)
528-
529526
struct QuasiArrayLayout <: MemoryLayout end
530527
MemoryLayout(::Type{<:AbstractQuasiArray}) = QuasiArrayLayout()

src/indices.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ getindex(S::Inclusion{T}, i::AbstractArray{T}) where T = (@_inline_meta; @bounds
194194
getindex(S::Inclusion, i::Inclusion) = (@_inline_meta; @boundscheck checkbounds(S, i); copy(S))
195195
getindex(S::Inclusion, ::Colon) = copy(S)
196196
Base.unsafe_getindex(S::Inclusion{T}, x) where T = convert(T, x)::T
197+
show(io::IO, r::Inclusion) = summary(io, r)
197198
summary(io::IO, r::Inclusion) = print(io, "Inclusion(", r.domain, ")")
198199
iterate(S::Inclusion, s...) = iterate(S.domain, s...)
199200

src/lazyquasiarrays.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,17 @@ call(b::BroadcastLayout, a::SubQuasiArray) = call(b, parent(a))
173173
# show
174174
####
175175

176+
show(io::IO, A::BroadcastQuasiArray) = summary(io, A)
177+
show(io::IO, A::ApplyQuasiArray) = summary(io, A)
176178

177179
summary(io::IO, A::BroadcastQuasiArray) = _broadcastarray_summary(io, A)
178180
summary(io::IO, A::ApplyQuasiArray) = _applyarray_summary(io, A)
179181

180182
_mul_summary(_, io, A) = _applyarray_summary(io, A)
181183
summary(io::IO, A::ApplyQuasiArray{<:Any,N,typeof(*)}) where N = _mul_summary(MemoryLayout(A), io, A)
182184

185+
186+
183187
for op in (:+, :-, :*, :\, :/)
184188
@eval begin
185189
function summary(io::IO, A::BroadcastQuasiArray{<:Any,N,typeof($op)}) where N

src/matmul.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ _factors(M) = (M,)
6464
@inline flatten(A::MulQuasiArray) = ApplyQuasiArray(flatten(Applied(A)))
6565
@inline flatten(A::SubQuasiArray{<:Any,2,<:MulQuasiArray}) = materialize(flatten(Applied(A)))
6666

67-
67+
colsupport(B::MulQuasiArray, j) = _mul_colsupport(j, reverse(B.args)...)
68+
rowsupport(B::MulQuasiArray, j) = _mul_rowsupport(j, B.args...)
69+
_mul_colsupport(j, Z::AbstractQuasiArray) = colsupport(Z,j)
70+
_mul_colsupport(j, Z::AbstractQuasiArray, Y...) = _mul_colsupport(colsupport(Z,j), Y...)
6871

6972
adjoint(A::MulQuasiArray) = ApplyQuasiArray(*, reverse(adjoint.(A.args))...)
7073
transpose(A::MulQuasiArray) = ApplyQuasiArray(*, reverse(transpose.(A.args))...)

src/quasiadjtrans.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,10 @@ map(f, tvs::QuasiTransposeAbsVec...) = transpose(map((xs...) -> transpose(f(tran
167167
## multiplication *
168168

169169
# QuasiAdjoint/QuasiTranspose-vector * vector
170+
dot(a::AbstractQuasiArray, b::AbstractQuasiArray) = ArrayLayouts.dot(a, b)
171+
@inline copy(d::Dot{<:Any,<:Any,<:AbstractQuasiArray,<:AbstractQuasiArray}) = _dot(size(d.A,1), d.A, d.B)
172+
_dot(sz, a, b) = Base.invoke(dot, NTuple{2,Any}, a, b)
170173

171-
# allow re-expansion using Legendre
172-
_dot(ax, a, b) = Base.invoke(dot, NTuple{2,Any}, a, b)
173-
174-
function dot(a::AbstractQuasiArray, b::AbstractQuasiArray)
175-
ax = axes(a,1)
176-
ax == axes(b,1) || throw(DimensionMismatch())
177-
_dot(ax, a, b)
178-
end
179174

180175
*(u::QuasiAdjointAbsVec, v::AbstractQuasiVector) = dot(u.parent, v)
181176
*(u::QuasiTransposeAbsVec{T}, v::AbstractQuasiVector{T}) where {T<:Real} = dot(u.parent, v)

src/quasiarray.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,19 @@ function reshape(A::QuasiVector, ax::Tuple{Any,OneTo{Int}})
149149
@assert ax == (axes(A,1),Base.OneTo(1))
150150
QuasiMatrix(reshape(A.parent,size(A.parent,1),1), (A.axes[1], Base.OneTo(1)))
151151
end
152+
153+
function show(io::IO, A::QuasiVector)
154+
print(io, "QuasiVector(")
155+
show(io, A.parent)
156+
print(io, ", ")
157+
show(io, A.axes[1])
158+
print(io, ")")
159+
end
160+
161+
function show(io::IO, A::QuasiMatrix)
162+
print(io, "QuasiMatrix(")
163+
show(io, A.parent)
164+
print(io, ", ")
165+
show(io, A.axes)
166+
print(io, ")")
167+
end

src/quasifill.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ end
194194
# show
195195
#######
196196

197+
show(io::IO, F::AbstractQuasiFill) = summary(io, F)
198+
197199
function summary(io::IO, F::QuasiOnes)
198200
print(io, "ones(")
199201
summary(io, F.axes[1])
@@ -488,7 +490,7 @@ MemoryLayout(::Type{<:QuasiZeros}) = ZerosLayout()
488490
MemoryLayout(::Type{<:QuasiOnes}) = OnesLayout()
489491

490492
_quasi_mul(M::Mul{ZerosLayout}, _) = QuasiZeros{eltype(M)}(axes(M))
491-
_quasi_mul(M::Mul{QuasiArrayLayout,ZerosLayout}, _) = QuasiZeros{eltype(M)}(axes(M))
493+
_quasi_mul(M::Mul{QuasiArrayLayout,ZerosLayout}, _) = FillArrays.mult_zeros(M.A, M.B)
492494
_quasi_mul(M::Mul{QuasiArrayLayout,ZerosLayout}, ::NTuple{N,OneTo{Int}}) where N = Zeros{eltype(M)}(axes(M))
493495
fillzeros(::Type{T}, a::Tuple{AbstractQuasiVector,Vararg{Any}}) where T<:Number = QuasiZeros{T}(a)
494496
fillzeros(::Type{T}, a::Tuple{Any,AbstractQuasiVector,Vararg{Any}}) where T<:Number = QuasiZeros{T}(a)
@@ -520,4 +522,12 @@ fill(c, x::Inclusion, y::Union{OneTo,IdentityUnitRange,Inclusion}...) = QuasiFil
520522
fill(c, x::Union{OneTo,IdentityUnitRange}, y::Inclusion, z::Union{OneTo,IdentityUnitRange,Inclusion}...) = QuasiFill(c, (x, y, z...))
521523

522524
iszero(x::AbstractQuasiFill) = iszero(getindex_value(x))
523-
isone(x::AbstractQuasiFill) = isone(getindex_value(x))
525+
isone(x::AbstractQuasiFill) = isone(getindex_value(x))
526+
527+
528+
function FillArrays.mult_zeros(a::AbstractQuasiArray, b)
529+
axes(a, 2) axes(b, 1) &&
530+
throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
531+
T = promote_type(eltype(a), eltype(b))
532+
fillsimilar(QuasiZeros{T}(), axes(a, 1), axes(b)[2:end]...)
533+
end

0 commit comments

Comments
 (0)