Skip to content

Commit 74dc23f

Browse files
committed
Improve operators
1 parent d4641e0 commit 74dc23f

File tree

2 files changed

+64
-45
lines changed

2 files changed

+64
-45
lines changed

src/bases/bases.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,8 @@ cumsum_layout(::ExpansionLayout, A, dims) = cumsum_layout(ApplyLayout{typeof(*)}
662662
###
663663
# diff
664664
###
665-
diff_layout(::AbstractBasisLayout, Vm; dims...) = error("Overload diff(::$(typeof(Vm)))")
666-
function diff_layout(::AbstractBasisLayout, a, order; dims...)
665+
diff_layout(::AbstractBasisLayout, Vm, order...; dims...) = error("Overload diff(::$(typeof(Vm)))")
666+
function diff_layout(::AbstractBasisLayout, a, order::Int; dims...)
667667
order < 0 && throw(ArgumentError("order must be non-negative"))
668668
order == 0 && return a
669669
isone(order) ? diff(a) : diff(diff(a), order-1)
@@ -730,9 +730,11 @@ abslaplacian_axis(::Inclusion{<:Number}, A, order=1; dims...) = -diff(A, 2order;
730730

731731
laplacian(A, order...; dims...) = laplacian_layout(MemoryLayout(A), A, order...; dims...)
732732
laplacian_layout(layout, A, order...; dims...) = laplacian_axis(axes(A,1), A, order...; dims...)
733-
laplacian_axis(::Inclusion{<:Number}, A, order...; dims...) = -abslaplacian(A, order...)
733+
laplacian_axis(::Inclusion{<:Number}, A, order=1; dims...) = diff(A, 2order; dims...)
734734

735735

736+
laplacian_layout(::ExpansionLayout, A, order...; dims...) = laplacian_layout(ApplyLayout{typeof(*)}(), A, order...; dims...)
737+
abslaplacian_layout(::ExpansionLayout, A, order...; dims...) = abslaplacian_layout(ApplyLayout{typeof(*)}(), A, order...; dims...)
736738

737739
function abslaplacian_layout(::SubBasisLayout, Vm, order...; dims::Integer=1)
738740
dims == 1 || error("not implemented")
@@ -744,6 +746,19 @@ function laplacian_layout(::SubBasisLayout, Vm, order...; dims::Integer=1)
744746
laplacian(parent(Vm), order...)[:,parentindices(Vm)[2]]
745747
end
746748

749+
function laplacian_layout(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVecOrMat, order...; dims=1)
750+
a = arguments(LAY, V)
751+
dims == 1 || throw(ArgumentError("cannot take laplacian a vector along dimension $dims"))
752+
*(laplacian(a[1], order...), tail(a)...)
753+
end
754+
755+
function abslaplacian_layout(LAY::ApplyLayout{typeof(*)}, V::AbstractQuasiVecOrMat, order...; dims=1)
756+
a = arguments(LAY, V)
757+
dims == 1 || throw(ArgumentError("cannot take abslaplacian a vector along dimension $dims"))
758+
*(abslaplacian(a[1], order...), tail(a)...)
759+
end
760+
761+
747762

748763
"""
749764
weaklaplacian(A)

src/operators.jl

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,15 @@ show(io::IO, ::MIME"text/plain", δ::DiracDelta) = show(io, δ)
108108
# Differentiation
109109
#########
110110

111+
abstract type AbstractDifferentialQuasiMatrix{T} <: LazyQuasiMatrix{T} end
112+
111113
"""
112114
Derivative(axis)
113115
114116
represents the differentiation (or finite-differences) operator on the
115117
specified axis.
116118
"""
117-
struct Derivative{T,D<:Inclusion,Order} <: LazyQuasiMatrix{T}
119+
struct Derivative{T,D<:Inclusion,Order} <: AbstractDifferentialQuasiMatrix{T}
118120
axis::D
119121
order::Order
120122
end
@@ -125,7 +127,7 @@ Laplacian(axis)
125127
represents the laplacian operator `Δ` on the
126128
specified axis.
127129
"""
128-
struct Laplacian{T,D<:Inclusion,Order} <: LazyQuasiMatrix{T}
130+
struct Laplacian{T,D<:Inclusion,Order} <: AbstractDifferentialQuasiMatrix{T}
129131
axis::D
130132
order::Order
131133
end
@@ -136,14 +138,42 @@ AbsLaplacian(axis)
136138
represents the positive-definite/negative/absolute-value laplacian operator `|Δ| ≡ -Δ` on the
137139
specified axis.
138140
"""
139-
struct AbsLaplacian{T,D<:Inclusion,Order} <: LazyQuasiMatrix{T}
141+
struct AbsLaplacian{T,D<:Inclusion,Order} <: AbstractDifferentialQuasiMatrix{T}
140142
axis::D
141143
order::Order
142144
end
143145

144-
_operatororder(D) = something(D.order, 1)
146+
operatororder(D) = something(D.order, 1)
147+
148+
show(io::IO, a::AbstractDifferentialQuasiMatrix) = summary(io, a)
149+
axes(D::AbstractDifferentialQuasiMatrix) = (D.axis, D.axis)
150+
==(a::AbstractDifferentialQuasiMatrix, b::AbstractDifferentialQuasiMatrix) = a.axis == b.axis && operatororder(a) == operatororder(b)
151+
copy(D::AbstractDifferentialQuasiMatrix) = D
152+
153+
154+
155+
@simplify function *(D::AbstractDifferentialQuasiMatrix, B::AbstractQuasiVecOrMat)
156+
T = typeof(zero(eltype(D)) * zero(eltype(B)))
157+
operatorcall(D, convert(AbstractQuasiArray{T}, B), D.order)
158+
end
159+
160+
^(D::AbstractDifferentialQuasiMatrix{T}, k::Integer) where T = similaroperator(D, D.axis, operatororder(D) .* k)
161+
162+
function view(D::AbstractDifferentialQuasiMatrix, kr::Inclusion, jr::Inclusion)
163+
@boundscheck axes(D,1) == kr == jr || throw(BoundsError(D,(kr,jr)))
164+
D
165+
end
166+
167+
operatorcall(D::AbstractDifferentialQuasiMatrix, B, order) = operatorcall(D)(B, order)
168+
operatorcall(D::AbstractDifferentialQuasiMatrix, B, ::Nothing) = operatorcall(D)(B)
169+
170+
171+
operatorcall(::Derivative) = diff
172+
operatorcall(::Laplacian) = laplacian
173+
operatorcall(::AbsLaplacian) = abslaplacian
145174

146-
for (Op, op) in ((:Derivative, :diff), (:Laplacian, :laplacian), (:AbsLaplacian, :abslaplacian))
175+
176+
for Op in (:Derivative, :Laplacian, :AbsLaplacian)
147177
@eval begin
148178
$Op{T, D}(axis::D, order) where {T,D<:Inclusion} = $Op{T,D,typeof(order)}(axis, order)
149179
$Op{T, D}(axis::D) where {T,D<:Inclusion} = $Op{T,D,Nothing}(axis, nothing)
@@ -153,42 +183,20 @@ for (Op, op) in ((:Derivative, :diff), (:Laplacian, :laplacian), (:AbsLaplacian,
153183
$Op(ax::AbstractQuasiVector{T}, order...) where T = $Op{eltype(eltype(ax))}(ax, order...)
154184
$Op(L::AbstractQuasiMatrix, order...) = $Op(axes(L,1), order...)
155185

156-
show(io::IO, a::$Op) = summary(io, a)
157-
function summary(io::IO, D::$Op{<:Any,<:Inclusion,Nothing})
158-
print(io, "$($Op)(")
159-
summary(io, D.axis)
160-
print(io,")")
161-
end
186+
similaroperator(D::$Op, ax, ord) = $Op{eltype(D)}(ax, ord)
187+
188+
simplifiable(::typeof(*), A::$Op, B::$Op) = Val(true)
189+
*(D1::$Op, D2::$Op) = similaroperator(convert(AbstractQuasiMatrix{promote_type(eltype(D1),eltype(D2))}, D1), D1.axis, operatororder(D1)+operatororder(D2))
190+
162191

163192
function summary(io::IO, D::$Op)
164193
print(io, "$($Op)(")
165194
summary(io, D.axis)
166-
print(io, ", ")
167-
print(io, D.order)
168-
print(io,")")
169-
end
170-
171-
axes(D::$Op) = (D.axis, D.axis)
172-
==(a::$Op, b::$Op) = a.axis == b.axis && _operatororder(a) == _operatororder(b)
173-
copy(D::$Op) = D
174-
175-
176-
@simplify function *(D::$Op, B::AbstractQuasiVecOrMat)
177-
T = typeof(zero(eltype(D)) * zero(eltype(B)))
178-
if D.order isa Nothing
179-
$op(convert(AbstractQuasiArray{T}, B))
180-
else
181-
$op(convert(AbstractQuasiArray{T}, B), D.order)
195+
if !isnothing(D.order)
196+
print(io, ", ")
197+
print(io, D.order)
182198
end
183-
end
184-
185-
^(D::$Op{T}, k::Integer) where T = $Op{T}(D.axis, _operatororder(D) .* k)
186-
187-
@simplify *(D1::$Op, D2::$Op) = $Op{promote_type(eltype(D1),eltype(D2))}(D1.axis, _operatororder(D1)+_operatororder(D2))
188-
189-
function view(D::$Op, kr::Inclusion, jr::Inclusion)
190-
@boundscheck axes(D,1) == kr == jr || throw(BoundsError(D,(kr,jr)))
191-
D
199+
print(io,")")
192200
end
193201
end
194202
end
@@ -199,12 +207,8 @@ end
199207
# end
200208

201209

202-
const Identity{T,D} = QuasiDiagonal{T,Inclusion{T,D}}
203-
204-
Identity(d::Inclusion) = QuasiDiagonal(d)
205-
206210
struct OperatorLayout <: AbstractLazyLayout end
207-
MemoryLayout(::Type{<:Derivative}) = OperatorLayout()
211+
MemoryLayout(::Type{<:AbstractDifferentialQuasiMatrix}) = OperatorLayout()
208212
# copy(M::Mul{OperatorLayout, <:ExpansionLayout}) = simplify(M)
209213
# copy(M::Mul{OperatorLayout, <:AbstractLazyLayout}) = M.A * expand(M.B)
210214

@@ -215,4 +219,4 @@ abs(Δ::Laplacian{T}) where T = AbsLaplacian{T}(axes(Δ,1), Δ.order)
215219
-::Laplacian{<:Any,<:Any,Nothing}) = abs(Δ)
216220
-::AbsLaplacian{T,<:Any,Nothing}) where T = Laplacian{T}.axis)
217221

218-
^::AbsLaplacian{T}, k::Real) where T = AbsLaplacian{T}.axis, _operatororder(Δ)*k)
222+
^::AbsLaplacian{T}, k::Real) where T = AbsLaplacian{T}.axis, operatororder(Δ)*k)

0 commit comments

Comments
 (0)