Skip to content

Commit d2897e7

Browse files
authored
Merge pull request #157 from vpuri3/5arg
Support `L(v, u, p, t, a, b)` corresponding to 5 arg `mul!`
2 parents ffc05f2 + bdd87e0 commit d2897e7

File tree

5 files changed

+103
-28
lines changed

5 files changed

+103
-28
lines changed

src/func.jl

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Matrix free operators (given by a function)
44
"""
5-
mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
5+
mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
66
""" Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """
77
op::F
88
""" Adjoint operator"""
@@ -33,11 +33,13 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst
3333

3434
iip = traits.isinplace
3535
oop = traits.outofplace
36+
mul5 = traits.has_mul5
3637
T = traits.T
3738

3839
new{
3940
iip,
4041
oop,
42+
mul5,
4143
T,
4244
typeof(op),
4345
typeof(op_adjoint),
@@ -88,6 +90,8 @@ function FunctionOperator(op,
8890

8991
isinplace::Union{Nothing,Bool}=nothing,
9092
outofplace::Union{Nothing,Bool}=nothing,
93+
has_mul5::Union{Nothing,Bool}=nothing,
94+
cache::Union{Nothing, NTuple{2}}=nothing,
9195
T::Union{Type{<:Number},Nothing}=nothing,
9296

9397
op_adjoint=nothing,
@@ -109,21 +113,34 @@ function FunctionOperator(op,
109113
)
110114

111115
sz = (size(output, 1), size(input, 1))
112-
T = T isa Nothing ? promote_type(eltype.((input, output))...) : T
113-
t = t isa Nothing ? zero(real(T)) : t
116+
T = isnothing(T) ? promote_type(eltype.((input, output))...) : T
117+
t = isnothing(t) ? zero(real(T)) : t
114118

115-
isinplace = if isinplace isa Nothing
119+
isinplace = if isnothing(isinplace)
116120
static_hasmethod(op, typeof((output, input, p, t)))
117121
else
118122
isinplace
119123
end
120124

121-
outofplace = if outofplace isa Nothing
125+
outofplace = if isnothing(outofplace)
122126
static_hasmethod(op, typeof((input, p, t)))
123127
else
124128
outofplace
125129
end
126130

131+
has_mul5 = if isnothing(has_mul5)
132+
has_mul5 = true
133+
for f in (
134+
op, op_adjoint, op_inverse, op_adjoint_inverse,
135+
)
136+
if !isnothing(f)
137+
has_mul5 *= static_hasmethod(f, typeof((output, input, p, t, t, t)))
138+
end
139+
end
140+
141+
has_mul5
142+
end
143+
127144
if !isinplace & !outofplace
128145
@error "Please provide a funciton with signatures `op(u, p, t)` for applying
129146
the operator out-of-place, and/or the signature is `op(du, u, p, t)` for
@@ -155,12 +172,12 @@ function FunctionOperator(op,
155172

156173
isinplace = isinplace,
157174
outofplace = outofplace,
175+
has_mul5 = has_mul5,
176+
ifcache = ifcache,
158177
T = T,
159178
size = sz,
160179
)
161180

162-
cache = nothing
163-
164181
L = FunctionOperator(
165182
op,
166183
op_adjoint,
@@ -172,7 +189,11 @@ function FunctionOperator(op,
172189
cache,
173190
)
174191

175-
ifcache ? cache_operator(L, input, output) : L
192+
if ifcache & isnothing(L.cache)
193+
L = cache_operator(L, input, output)
194+
end
195+
196+
L
176197
end
177198

178199
function update_coefficients(L::FunctionOperator, u, p, t)
@@ -204,7 +225,13 @@ function update_coefficients!(L::FunctionOperator, u, p, t)
204225
nothing
205226
end
206227

228+
function iscached(L::FunctionOperator)
229+
L.traits.ifcache ? !isnothing(L.cache) : !L.traits.ifcache
230+
!isnothing(L.cache)
231+
end
232+
207233
function cache_self(L::FunctionOperator, u::AbstractVecOrMat, v::AbstractVecOrMat)
234+
L.traits.ifcache && @warn "you are allocating cache for a FunctionOperator for which ifcache = false."
208235
@set! L.cache = zero.((u, v))
209236
L
210237
end
@@ -365,7 +392,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::
365392
@error "LinearAlgebra.mul! not defined for out-of-place FunctionOperators"
366393
end
367394

368-
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat, α, β)
395+
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, false}, u::AbstractVecOrMat, α, β) where{oop}
369396
_, co = L.cache
370397

371398
copy!(co, v)
@@ -374,6 +401,10 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::A
374401
axpy!(β, co, v)
375402
end
376403

404+
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, true}, u::AbstractVecOrMat, α, β) where{oop}
405+
L.op(v, u, L.p, L.t, α, β)
406+
end
407+
377408
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
378409
L.op_inverse(v, u, L.p, L.t)
379410
end

src/interface.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,24 @@ end
2828

2929
(L::AbstractSciMLOperator)(u, p, t) = (update_coefficients!(L, u, p, t); L * u)
3030
(L::AbstractSciMLOperator)(du, u, p, t) = (update_coefficients!(L, u, p, t); mul!(du, L, u))
31+
(L::AbstractSciMLOperator)(du, u, p, t, α, β) = (update_coefficients!(L, u, p, t); mul!(du, L, u, α, β))
3132

3233
###
3334
# caching interface
3435
###
3536

37+
getops(L) = ()
38+
3639
function iscached(L::AbstractSciMLOperator)
40+
3741
has_cache = hasfield(typeof(L), :cache) # TODO - confirm this is static
38-
isset = has_cache ? L.cache !== nothing : true
42+
isset = has_cache ? !isnothing(L.cache) : true
3943

4044
return isset & all(iscached, getops(L))
4145
end
4246

4347
iscached(L) = true
48+
4449
iscached(::Union{
4550
# LinearAlgebra
4651
AbstractMatrix,
@@ -61,22 +66,22 @@ arguments:
6166
in :: AbstractVecOrMat input prototype to L
6267
out :: (optional) AbstractVecOrMat output prototype to L
6368
"""
64-
cache_operator
69+
function cache_operator end
6570

6671
cache_operator(L, u) = L
67-
cache_operatro(L, u, v) = L
68-
cache_self(L::AbstractSciMLOperator, uv::AbstractVecOrMat...) = L
69-
cache_internals(L::AbstractSciMLOperator, uv::AbstractVecOrMat...) = L
72+
cache_operator(L, u, v) = L
73+
cache_self(L::AbstractSciMLOperator, ::AbstractVecOrMat...) = L
74+
cache_internals(L::AbstractSciMLOperator, ::AbstractVecOrMat...) = L
75+
76+
function cache_operator(L::AbstractSciMLOperator, u::AbstractVecOrMat, v::AbstractVecOrMat)
7077

71-
function cache_operator(L::AbstractSciMLOperator,
72-
u::AbstractVecOrMat,
73-
v::AbstractVecOrMat)
7478
L = cache_self(L, u, v)
7579
L = cache_internals(L, u, v)
7680
L
7781
end
7882

7983
function cache_operator(L::AbstractSciMLOperator, u::AbstractVecOrMat)
84+
8085
L = cache_self(L, u)
8186
L = cache_internals(L, u)
8287
L

test/func.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ K = 12
1717

1818
A = rand(N,N) |> Symmetric
1919
F = lu(A)
20+
Ai = inv(A)
2021

2122
f1(u, p, t) = A * u
2223
f1i(u, p, t) = A \ u
2324

2425
f2(du, u, p, t) = mul!(du, A, u)
26+
f2(du, u, p, t, α, β) = mul!(du, A, u, α, β)
2527
f2i(du, u, p, t) = ldiv!(du, F, u)
28+
f2i(du, u, p, t, α, β) = mul!(du, Ai, u, α, β)
2629

2730
# out of place
2831
op1 = FunctionOperator(f1, u, A*u;
@@ -51,6 +54,7 @@ K = 12
5154
ishermitian=true,
5255
isposdef=true,
5356
)
57+
5458
@test issquare(op1)
5559
@test issquare(op2)
5660

@@ -76,6 +80,12 @@ K = 12
7680
@test !iscached(op1)
7781
@test !iscached(op2)
7882

83+
@test !op1.traits.has_mul5
84+
@test op2.traits.has_mul5
85+
86+
# 5-arg mul! (w/o cache)
87+
v = rand(N,K); w=copy(v); @test α*(A*u)+ β*w mul!(v, op2, u, α, β)
88+
7989
op1 = cache_operator(op1, u, A * u)
8090
op2 = cache_operator(op2, u, A * u)
8191

test/matrix.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,27 @@ end
6565
p = rand(N)
6666
t = rand()
6767

68+
α = rand()
69+
β = rand()
70+
6871
L = MatrixOperator(zeros(N,N);
6972
update_func = (A,u,p,t) -> (A .= p*p'; nothing)
7073
)
7174

7275
@test !isconstant(L)
7376

74-
A = p*p'
75-
ans = A * u
76-
@test L(u,p,t) ans
77-
v=copy(u); @test L(v,u,p,t) ans
77+
A = p * p'
78+
@test L(u, p, t) A * u
79+
v=copy(u); @test L(v, u, p, t) A * u
80+
v=rand(N,K); w=copy(v); @test L(v, u, p, t, α, β) α*A*u + β*w
7881
end
7982

8083
@testset "DiagonalOperator update test" begin
8184
u = rand(N,K)
8285
p = rand(N)
8386
t = rand()
87+
α = rand()
88+
β = rand()
8489

8590
D = DiagonalOperator(zeros(N);
8691
update_func = (diag,u,p,t) -> (diag .= p*t; nothing)
@@ -93,6 +98,7 @@ end
9398
ans = Diagonal(p*t) * u
9499
@test D(u,p,t) ans
95100
v=copy(u); @test D(v,u,p,t) ans
101+
v=rand(N,K); w=copy(v); @test D(v, u, p, t, α, β) α*ans + β*w
96102
end
97103

98104
@testset "Batched Diagonal Operator" begin
@@ -173,6 +179,8 @@ end
173179
u = rand(N,K)
174180
p = rand(N)
175181
t = rand()
182+
α = rand()
183+
β = rand()
176184

177185
L = AffineOperator(A, B, b;
178186
update_func = (b,u,p,t) -> (b .= Diagonal(p*t)*b; nothing)
@@ -186,6 +194,9 @@ end
186194
b = Diagonal(p*t)*b
187195
ans = A * u + B * b
188196
v=copy(u); @test L(v,u,p,t) ans
197+
b = Diagonal(p*t)*b
198+
ans = A * u + B * b
199+
v=rand(N,K); w=copy(v); @test L(v, u, p, t, α, β) α*ans + β*w
189200
end
190201

191202
@testset "TensorProductOperator" begin

test/scalar.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,26 +61,44 @@ K = 12
6161
end
6262

6363
@testset "ScalarOperator update test" begin
64-
u = ones(N,K)
65-
v = zeros(N,K)
64+
u = rand(N,K)
65+
v = rand(N,K)
6666
p = rand()
6767
t = rand()
68+
a = rand()
69+
b = rand()
6870

6971
α = ScalarOperator(0.0; update_func=(a,u,p,t) -> p)
7072
β = ScalarOperator(0.0; update_func=(a,u,p,t) -> t)
7173

7274
@test !isconstant(α)
7375
@test !isconstant(β)
7476

75-
@test α(u,p,t) p * u
76-
@test α(v,u,p,t) p * u
77+
@test convert(Number, α) 0.0
78+
@test convert(Number, β) 0.0
79+
80+
update_coefficients!(α, u, p, t)
81+
update_coefficients!(β, u, p, t)
82+
83+
@test convert(Number, α) p
84+
@test convert(Number, β) t
85+
86+
@test α(u, p, t) p * u
87+
v=rand(N,K); @test α(v, u, p, t) p * u
88+
v=rand(N,K); w=copy(v); @test α(v, u, p, t, a, b) a*p*u + b*w
89+
90+
@test β(u, p, t) t * u
91+
v=rand(N,K); @test β(v, u, p, t) t * u
92+
v=rand(N,K); w=copy(v); @test β(v, u, p, t, a, b) a*t*u + b*w
7793

7894
num = α + 2 / β * 3 - 4
7995
val = p + 2 / t * 3 - 4
8096

81-
@test num(u,p,t) val * u
82-
@test num(v,u,p,t) val * u
83-
8497
@test convert(Number, num) val
98+
99+
@test num(u, p, t) val * u
100+
v=rand(N,K); @test num(v, u, p, t) val * u
101+
v=rand(N,K); w=copy(v); @test num(v, u, p, t, a, b) a*val*u + b*w
102+
85103
end
86104
#

0 commit comments

Comments
 (0)