Skip to content

Commit d6f6186

Browse files
authored
Merge pull request #184 from vpuri3/fixes
Fixes
2 parents 78285c8 + fc4bf13 commit d6f6186

File tree

6 files changed

+99
-33
lines changed

6 files changed

+99
-33
lines changed

src/basic.jl

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ for op in (
8282
end
8383
end
8484

85-
function Base.:\(::IdentityOperator, A::AbstractSciMLOperator)
85+
function Base.:\(ii::IdentityOperator, A::AbstractSciMLOperator)
8686
@assert size(A, 1) == ii.len
8787
A
8888
end
8989

90-
function Base.:/(A::AbstractSciMLOperator, ::IdentityOperator)
90+
function Base.:/(A::AbstractSciMLOperator, ii::IdentityOperator)
9191
@assert size(A, 2) == ii.len
9292
A
9393
end
@@ -330,8 +330,9 @@ AddedOperator(L::AbstractSciMLOperator) = L
330330
# constructors
331331
Base.:+(A::AbstractSciMLOperator, B::AbstractMatrix) = A + MatrixOperator(B)
332332
Base.:+(A::AbstractMatrix, B::AbstractSciMLOperator) = MatrixOperator(A) + B
333-
Base.:+(ops::AbstractSciMLOperator...) = AddedOperator(ops...)
334333

334+
Base.:+(ops::AbstractSciMLOperator...) = reduce(+, ops)
335+
Base.:+(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = AddedOperator(A, B)
335336
Base.:+(A::AbstractSciMLOperator, B::AddedOperator) = AddedOperator(A, B.ops...)
336337
Base.:+(A::AddedOperator, B::AbstractSciMLOperator) = AddedOperator(A.ops..., B)
337338
Base.:+(A::AddedOperator, B::AddedOperator) = AddedOperator(A.ops..., B.ops...)
@@ -471,16 +472,15 @@ function ComposedOperator(ops::AbstractSciMLOperator...; cache = nothing)
471472
end
472473

473474
# constructors
474-
Base.:(ops::AbstractSciMLOperator...) = ComposedOperator(ops...)
475-
Base.:(A::ComposedOperator, B::ComposedOperator) = ComposedOperator(A.ops..., B.ops...)
476-
Base.:(A::AbstractSciMLOperator, B::ComposedOperator) = ComposedOperator(A, B.ops...)
477-
Base.:(A::ComposedOperator, B::AbstractSciMLOperator) = ComposedOperator(A.ops..., B)
478-
479-
Base.:*(ops::AbstractSciMLOperator...) = ComposedOperator(ops...)
480-
Base.:*(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = (A, B)
481-
Base.:*(A::ComposedOperator, B::AbstractSciMLOperator) = (A.ops[1:end-1]..., A.ops[end] * B)
482-
Base.:*(A::AbstractSciMLOperator, B::ComposedOperator) = (A * B.ops[1], B.ops[2:end]...)
483-
Base.:*(A::ComposedOperator, B::ComposedOperator) = ComposedOperator(A.ops..., B.ops...)
475+
for op in (
476+
:*, :,
477+
)
478+
@eval Base.$op(ops::AbstractSciMLOperator...) = reduce($op, ops)
479+
@eval Base.$op(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = ComposedOperator(A, B)
480+
@eval Base.$op(A::ComposedOperator, B::AbstractSciMLOperator) = ComposedOperator(A.ops..., B)
481+
@eval Base.$op(A::AbstractSciMLOperator, B::ComposedOperator) = ComposedOperator(A, B.ops...)
482+
@eval Base.$op(A::ComposedOperator, B::ComposedOperator) = ComposedOperator(A.ops..., B.ops...)
483+
end
484484

485485
for op in (
486486
:*, :,
@@ -606,11 +606,20 @@ function cache_self(L::ComposedOperator, u::AbstractVecOrMat)
606606
K = size(u, 2)
607607
cache = (zero(u),)
608608
for i in reverse(2:length(L.ops))
609+
op = L.ops[i]
609610

610-
M = size(L.ops[i], 1)
611-
T = promote_type(eltype.((L.ops[i], cache[1]))...)
611+
M = size(op, 1)
612612
sz = u isa AbstractMatrix ? (M, K) : (M,)
613613

614+
T = if op isa FunctionOperator #
615+
# FunctionOperator isn't guaranteed to play by the rules of
616+
# `promote_type`. For example, an rFFT is a complex operation
617+
# that accepts and complex vector and returns a real one.
618+
op.traits.eltypes[2]
619+
else
620+
promote_type(eltype.((op, cache[1]))...)
621+
end
622+
614623
cache = (similar(u, T, sz), cache...)
615624
end
616625

@@ -623,12 +632,12 @@ function cache_internals(L::ComposedOperator, u::AbstractVecOrMat)
623632
L = cache_self(L, u)
624633
end
625634

626-
vecs = L.cache
635+
ops = ()
627636
for i in reverse(1:length(L.ops))
628-
@set! L.ops[i] = cache_operator(L.ops[i], vecs[i])
637+
ops = (cache_operator(L.ops[i], L.cache[i]), ops...)
629638
end
630639

631-
L
640+
@set! L.ops = ops
632641
end
633642

634643
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat)

src/func.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,11 @@ end
8686

8787
function FunctionOperator(op,
8888
input::AbstractVecOrMat,
89-
output::AbstractVecOrMat = input;
89+
output::AbstractVecOrMat = input;
9090

9191
isinplace::Union{Nothing,Bool}=nothing,
9292
outofplace::Union{Nothing,Bool}=nothing,
93+
isconstant::Bool = false,
9394
has_mul5::Union{Nothing,Bool}=nothing,
9495
cache::Union{Nothing, NTuple{2}}=nothing,
9596
T::Union{Type{<:Number},Nothing}=nothing,
@@ -112,8 +113,10 @@ function FunctionOperator(op,
112113
isposdef::Bool = false,
113114
)
114115

116+
# store eltype of input/output for caching with ComposedOperator.
117+
eltypes = eltype.((input, output))
115118
sz = (size(output, 1), size(input, 1))
116-
T = isnothing(T) ? promote_type(eltype.((input, output))...) : T
119+
T = isnothing(T) ? promote_type(eltypes...) : T
117120
t = isnothing(t) ? zero(real(T)) : t
118121

119122
isinplace = if isnothing(isinplace)
@@ -164,6 +167,7 @@ function FunctionOperator(op,
164167

165168
traits = (;
166169
islinear = islinear,
170+
isconstant = isconstant,
167171

168172
opnorm = opnorm,
169173
issymmetric = issymmetric,
@@ -176,6 +180,7 @@ function FunctionOperator(op,
176180
ifcache = ifcache,
177181
T = T,
178182
size = sz,
183+
eltypes = eltypes,
179184
)
180185

181186
L = FunctionOperator(
@@ -197,6 +202,11 @@ function FunctionOperator(op,
197202
end
198203

199204
function update_coefficients(L::FunctionOperator, u, p, t)
205+
206+
if isconstant(L)
207+
return L
208+
end
209+
200210
@set! L.op = update_coefficients(L.op, u, p, t)
201211
@set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t)
202212
@set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t)
@@ -209,6 +219,11 @@ function update_coefficients(L::FunctionOperator, u, p, t)
209219
end
210220

211221
function update_coefficients!(L::FunctionOperator, u, p, t)
222+
223+
if isconstant(L)
224+
return L
225+
end
226+
212227
for op in getops(L)
213228
update_coefficients!(op, u, p, t)
214229
end
@@ -250,6 +265,7 @@ function Base.adjoint(L::FunctionOperator)
250265

251266
traits = L.traits
252267
@set! traits.size = reverse(size(L))
268+
@set! traits.eltypes = reverse(traits.eltypes)
253269

254270
p = L.p
255271
t = L.t
@@ -284,6 +300,7 @@ function Base.inv(L::FunctionOperator)
284300

285301
traits = L.traits
286302
@set! traits.size = reverse(size(L))
303+
@set! traits.eltypes = reverse(traits.eltypes)
287304

288305
@set! traits.opnorm = if traits.opnorm isa Number
289306
1 / traits.opnorm
@@ -353,8 +370,8 @@ function getops(L::FunctionOperator)
353370
ops
354371
end
355372

356-
#TODO - isconstant(L::FunctionOperator)
357373
islinear(L::FunctionOperator) = L.traits.islinear
374+
isconstant(L::FunctionOperator) = L.traits.isconstant
358375
has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing)
359376
has_mul(L::FunctionOperator{iip}) where{iip} = true
360377
has_mul!(L::FunctionOperator{iip}) where{iip} = iip

src/scalar.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,11 @@ end
214214
for op in (
215215
:*, :,
216216
)
217-
@eval Base.$op(ops::AbstractSciMLScalarOperator...) = ComposedScalarOperator(ops...)
218-
@eval Base.$op(A::ComposedScalarOperator, B::ComposedScalarOperator) = ComposedScalarOperator(A.ops..., B.ops...)
219-
@eval Base.$op(A::AbstractSciMLScalarOperator, B::ComposedScalarOperator) = ComposedScalarOperator(A, B.ops...)
217+
@eval Base.$op(ops::AbstractSciMLScalarOperator...) = reduce($op, ops)
218+
@eval Base.$op(A::AbstractSciMLScalarOperator, B::AbstractSciMLScalarOperator) = ComposedScalarOperator(A, B)
220219
@eval Base.$op(A::ComposedScalarOperator, B::AbstractSciMLScalarOperator) = ComposedScalarOperator(A.ops..., B)
220+
@eval Base.$op(A::AbstractSciMLScalarOperator, B::ComposedScalarOperator) = ComposedScalarOperator(A, B.ops...)
221+
@eval Base.$op(A::ComposedScalarOperator, B::ComposedScalarOperator) = ComposedScalarOperator(A.ops..., B.ops...)
221222

222223
for T in SCALINGNUMBERTYPES[2:end]
223224
@eval Base.$op::AbstractSciMLScalarOperator, x::$T) = ComposedScalarOperator(α, ScalarOperator(x))

test/basic.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,15 @@ end
171171

172172
v=rand(N,K); @test mul!(v, op, u) (A+B) * u
173173
v=rand(N,K); w=copy(v); @test mul!(v, op, u, α, β) α*(A+B)*u + β*w
174+
175+
# ensure AddedOperator doesn't nest
176+
A = MatrixOperator(rand(N, N))
177+
L = A + (A + A) + A
178+
@test L isa AddedOperator
179+
for op in L.ops
180+
@test !isa(op, AddedOperator)
181+
end
182+
174183
end
175184

176185
@testset "ComposedOperator" begin
@@ -221,6 +230,14 @@ end
221230
v=rand(N,K); @test ldiv!(v, op, u) (A * B * C) \ u
222231
v=copy(u); @test ldiv!(op, u) (A * B * C) \ v
223232

233+
# ensure composedoperators doesn't nest
234+
A = MatrixOperator(rand(N, N))
235+
L = A * (A * A) * A
236+
@test L isa ComposedOperator
237+
for op in L.ops
238+
@test !isa(op, ComposedOperator)
239+
end
240+
224241
# Test caching of composed operator when inner ops do not support Base.:*
225242
# ComposedOperator caching was modified in PR # 174
226243
inner_op = qr(MatrixOperator(rand(N, N)))

test/scalar.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
#
2-
using SciMLOperators, LinearAlgebra
3-
using Random
2+
using SciMLOperators
3+
using SciMLOperators: AbstractSciMLScalarOperator,
4+
ComposedScalarOperator,
5+
AddedScalarOperator,
6+
InvertedScalarOperator,
7+
IdentityOperator,
8+
AddedOperator,
9+
ScaledOperator
10+
11+
using LinearAlgebra, Random
412

513
Random.seed!(0)
614
N = 8
@@ -38,31 +46,31 @@ K = 12
3846

3947
# Test that ScalarOperator's remain AbstractSciMLScalarOperator's under common ops
4048
β = α + α
41-
@test β isa SciMLOperators.AddedScalarOperator
49+
@test β isa AddedScalarOperator
4250
@test β * u x * u + x * u
4351
@inferred convert(Float32, β)
4452
@test convert(Number, β) x + x
4553

4654
β = α * α
47-
@test β isa SciMLOperators.ComposedScalarOperator
55+
@test β isa ComposedScalarOperator
4856
@test β * u x * x * u
4957
@inferred convert(Float32, β)
5058
@test convert(Number, β) x * x
5159

5260
β = inv(α)
53-
@test β isa SciMLOperators.InvertedScalarOperator
61+
@test β isa InvertedScalarOperator
5462
@test β * u 1 / x * u
5563
@inferred convert(Float32, β)
5664
@test convert(Number, β) 1 / x
5765

5866
β = α * inv(α)
59-
@test β isa SciMLOperators.ComposedScalarOperator
67+
@test β isa ComposedScalarOperator
6068
@test β * u u
6169
@inferred convert(Float32, β)
6270
@test convert(Number, β) true
6371

6472
β = α / α
65-
@test β isa SciMLOperators.ComposedScalarOperator
73+
@test β isa ComposedScalarOperator
6674
@test β * u u
6775
@inferred convert(Float32, β)
6876
@test convert(Number, β) true
@@ -77,6 +85,14 @@ K = 12
7785
@test/ op) * u (op \ α) * u α * (op \ u)
7886
@test (op / α) * u \ op) * u 1/α * op * u
7987
end
88+
89+
# ensure composedscalaroperators doesn't nest
90+
α = ScalarOperator(rand())
91+
L = α ** α) * α
92+
@test L isa ComposedScalarOperator
93+
for op in L.ops
94+
@test !isa(op, ComposedScalarOperator)
95+
end
8096
end
8197

8298
@testset "ScalarOperator update test" begin

test/total.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,17 @@ K = 12
3535
ik = im * DiagonalOperator(k)
3636
Dx = ftr \ ik * ftr
3737
Dx = cache_operator(Dx, x)
38+
D2x = cache_operator(Dx * Dx, x)
3839

39-
u = @. sin(5x)cos(7x);
40-
du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x);
40+
u = @. sin(5x)cos(7x);
41+
du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x);
42+
d2u = @. 5(-5sin(5x)cos(7x) -7cos(5x)sin(7x)) +
43+
- 7(5cos(5x)sin(7x) + 7sin(5x)cos(7x))
4144

4245
@test (Dx * u, du; atol=1e-8)
46+
@test (D2x * u, d2u; atol=1e-8)
47+
48+
v = copy(u); @test (mul!(v, D2x, u), d2u; atol=1e-8)
4349
v = copy(u); @test (mul!(v, Dx, u), du; atol=1e-8)
4450

4551
itr = inv(ftr)

0 commit comments

Comments
 (0)