Skip to content

Commit 4145a37

Browse files
Merge pull request #96 from vpuri3/oop
allow functionoperator to specify both in-place/ out-of-place mode
2 parents 0b49024 + 56a3cd2 commit 4145a37

File tree

8 files changed

+72
-397
lines changed

8 files changed

+72
-397
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLOperators"
22
uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
33
authors = ["xtalax <[email protected]>"]
4-
version = "0.1.12"
4+
version = "0.1.13"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
@@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
14+
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
1415

1516
[compat]
1617
ArrayInterfaceCore = "0.1"

src/SciMLOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import StaticArrays
77
import SparseArrays
88
import ArrayInterfaceCore
99
import Base: ReshapedArray
10+
import Tricks: static_hasmethod
1011
import Lazy: @forward
1112
import Setfield: @set!
1213

src/func.jl

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Matrix free operators (given by a function)
44
"""
5-
mutable struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
5+
mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
66
""" Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """
77
op::F
88
""" Adjoint operator"""
@@ -22,7 +22,8 @@ mutable struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSc
2222
""" Cache """
2323
cache::C
2424

25-
function FunctionOperator(op,
25+
function FunctionOperator(
26+
op,
2627
op_adjoint,
2728
op_inverse,
2829
op_adjoint_inverse,
@@ -34,11 +35,14 @@ mutable struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSc
3435
)
3536

3637
iip = traits.isinplace
38+
oop = traits.outofplace
3739
T = traits.T
3840

3941
isset = cache !== nothing
4042

41-
new{iip,
43+
new{
44+
iip,
45+
oop,
4246
T,
4347
typeof(op),
4448
typeof(op_adjoint),
@@ -62,43 +66,52 @@ mutable struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSc
6266
end
6367
end
6468

65-
function FunctionOperator(op;
69+
function FunctionOperator(op,
70+
input::AbstractVecOrMat,
71+
output::AbstractVecOrMat;
6672

67-
# necessary
68-
isinplace=nothing,
69-
T=nothing,
70-
size=nothing,
73+
isinplace::Union{Nothing,Bool}=nothing,
74+
outofplace::Union{Nothing,Bool}=nothing,
75+
T::Union{Type{<:Number},Nothing}=nothing,
7176

72-
input_prototype=nothing,
73-
output_prototype=nothing,
74-
75-
# optional
7677
op_adjoint=nothing,
7778
op_inverse=nothing,
7879
op_adjoint_inverse=nothing,
7980

8081
p=nothing,
81-
t=nothing,
82+
t::Union{Number,Nothing}=nothing,
8283

8384
# traits
8485
opnorm=nothing,
85-
issymmetric=false,
86-
ishermitian=false,
87-
isposdef=false,
86+
issymmetric::Bool=false,
87+
ishermitian::Bool=false,
88+
isposdef::Bool=false,
8889
)
8990

90-
isinplace isa Nothing && @error "Please provide a funciton signature
91-
by specifying `isinplace` as either `true`, or `false`.
92-
If `isinplace = false`, the signature is `op(u, p, t)`,
93-
and if `isinplace = true`, the signature is `op(du, u, p, t)`.
94-
Further, it is assumed that the function call would be nonallocating
95-
when called in-place"
96-
T isa Nothing && @error "Please provide a Number type for the Operator"
97-
size isa Nothing && @error "Please provide a size (m, n)"
98-
if (input_prototype isa Nothing) | (output_prototype isa Nothing)
99-
@error "Please provide input/out prototypes vectors/arrays."
91+
sz = (size(output, 1), size(input, 1))
92+
T = T isa Nothing ? promote_type(eltype.((input, output))...) : T
93+
t = t isa Nothing ? zero(real(T)) : t
94+
95+
isinplace = if isinplace isa Nothing
96+
static_hasmethod(op, typeof((output, input, p, t)))
97+
else
98+
isinplace
99+
end
100+
101+
outofplace = if outofplace isa Nothing
102+
static_hasmethod(op, typeof((input, p, t)))
103+
else
104+
outofplace
100105
end
101106

107+
if !isinplace & !outofplace
108+
@error "Please provide a funciton with signatures `op(u, p, t)` for applying
109+
the operator out-of-place, and/or the signature is `op(du, u, p, t)` for
110+
in-place application."
111+
end
112+
113+
T isa Nothing && @error "Please provide a Number type for the Operator"
114+
102115
isreal = T <: Real
103116
selfadjoint = ishermitian | (isreal & issymmetric)
104117
adjointable = !(op_adjoint isa Nothing) | selfadjoint
@@ -112,24 +125,20 @@ function FunctionOperator(op;
112125
op_adjoint_inverse = op_inverse
113126
end
114127

115-
t = t isa Nothing ? zero(T) : t
116-
117128
traits = (;
118129
opnorm = opnorm,
119130
issymmetric = issymmetric,
120131
ishermitian = ishermitian,
121132
isposdef = isposdef,
122133

123134
isinplace = isinplace,
135+
outofplace = outofplace,
124136
T = T,
125-
size = size,
137+
size = sz,
126138
)
127139

128-
cache = (
129-
zero(input_prototype),
130-
zero(output_prototype),
131-
)
132-
isset = cache === nothing
140+
cache = zero.((input, output))
141+
isset = true
133142

134143
FunctionOperator(
135144
op,
@@ -260,16 +269,16 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin
260269
# TODO - FunctionOperator, Base.conj, transpose
261270

262271
# operator application
263-
Base.:*(L::FunctionOperator{false}, u::AbstractVecOrMat) = L.op(u, L.p, L.t)
264-
Base.:\(L::FunctionOperator{false}, u::AbstractVecOrMat) = L.op_inverse(u, L.p, L.t)
272+
Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op(u, L.p, L.t)
273+
Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t)
265274

266-
function Base.:*(L::FunctionOperator{true}, u::AbstractVecOrMat)
275+
function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
267276
_, co = L.cache
268277
du = zero(co)
269278
L.op(du, u, L.p, L.t)
270279
end
271280

272-
function Base.:\(L::FunctionOperator{true}, u::AbstractVecOrMat)
281+
function Base.:\(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
273282
ci, _ = L.cache
274283
du = zero(ci)
275284
L.op_inverse(du, u, L.p, L.t)

test/func.jl

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ using SciMLOperators, LinearAlgebra
33
using Random
44

55
using SciMLOperators: InvertibleOperator,
6-
using FFTW
76

87
Random.seed!(0)
98
N = 8
@@ -26,15 +25,7 @@ K = 12
2625
f2i(du, u, p, t) = ldiv!(du, F, u)
2726

2827
# out of place
29-
op1 = FunctionOperator(
30-
f1;
31-
32-
isinplace=false,
33-
T=Float64,
34-
size=(N,N),
35-
36-
input_prototype=u,
37-
output_prototype=A*u,
28+
op1 = FunctionOperator(f1, u, A*u;
3829

3930
op_inverse=f1i,
4031

@@ -45,15 +36,7 @@ K = 12
4536
)
4637

4738
# in place
48-
op2 = FunctionOperator(
49-
f2;
50-
51-
isinplace=true,
52-
T=Float64,
53-
size=(N,N),
54-
55-
input_prototype=u,
56-
output_prototype=A*u,
39+
op2 = FunctionOperator(f2, u, A*u;
5740

5841
op_inverse=f2i,
5942

@@ -96,19 +79,7 @@ end
9679

9780
f(du,u,p,t) = mul!(du, Diagonal(p*t), u)
9881

99-
op = FunctionOperator(
100-
f;
101-
102-
isinplace=true,
103-
T=Float64,
104-
size=(N,N),
105-
106-
input_prototype=u,
107-
output_prototype=u,
108-
109-
p=p*0.0,
110-
t=0.0,
111-
)
82+
op = FunctionOperator(f, u, u; p=zero(p), t=zero(t))
11283

11384
ans = @. u * p * t
11485
@test op(u,p,t) ans

test/matrix.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,23 @@ end
7878
v=copy(u); @test D(v,u,p,t) ans
7979
end
8080

81+
@testset "Batched Diagonal Operator" begin
82+
u = rand(N,K)
83+
d = rand(N,K)
84+
α = rand()
85+
β = rand()
86+
87+
L = DiagonalOperator(d)
88+
89+
@test L * u d .* u
90+
v=rand(N,K); @test mul!(v, L, u) d .* u
91+
v=rand(N,K); w=copy(v); @test mul!(v, L, u, α, β) α*(d .* u) + β*w
92+
93+
@test L \ u d .\ u
94+
v=rand(N,K); @test ldiv!(v, L, u) d .\ u
95+
v=copy(u); @test ldiv!(L, u) d .\ v
96+
end
97+
8198
@testset "AffineOperator" begin
8299
u = rand(N,K)
83100
A = rand(N,N)

0 commit comments

Comments
 (0)