Skip to content

Commit 2039879

Browse files
committed
feat: add callable structs
1 parent 7e3c585 commit 2039879

File tree

3 files changed

+94
-56
lines changed

3 files changed

+94
-56
lines changed

lib/SciMLJacobianOperators/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
99
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1010
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1111
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
12+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1314
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
1415
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -19,6 +20,7 @@ ConcreteStructs = "0.2.3"
1920
ConstructionBase = "1.5.8"
2021
DifferentiationInterface = "0.5.17"
2122
FastClosures = "0.3.2"
23+
LinearAlgebra = "1.11.0"
2224
SciMLOperators = "0.3.10"
2325
Setfield = "1.1.1"
2426
julia = "1.10"

lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl

Lines changed: 92 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ const DI = DifferentiationInterface
1313
const True = Val(true)
1414
const False = Val(false)
1515

16+
abstract type AbstractJacobianOperator{T} <: AbstractSciMLOperator{T} end
17+
1618
abstract type AbstractMode end
1719

1820
struct VJP <: AbstractMode end
@@ -21,17 +23,20 @@ struct JVP <: AbstractMode end
2123
flip_mode(::VJP) = JVP()
2224
flip_mode(::JVP) = VJP()
2325

24-
@concrete struct JacobianOperator{iip, T <: Real} <: AbstractSciMLOperator{T}
26+
@concrete struct JacobianOperator{iip, T <: Real} <: AbstractJacobianOperator{T}
2527
mode <: AbstractMode
2628

2729
jvp_op
2830
vjp_op
2931

3032
size
31-
jvp_extras
32-
vjp_extras
33+
34+
output_cache
35+
input_cache
3336
end
3437

38+
SciMLBase.isinplace(::JacobianOperator{iip}) where {iip} = iip
39+
3540
function ConstructionBase.constructorof(::Type{<:JacobianOperator{iip, T}}) where {iip, T}
3641
return JacobianOperator{iip, T}
3742
end
@@ -42,6 +47,9 @@ Base.size(J::JacobianOperator, d::Integer) = J.size[d]
4247
for op in (:adjoint, :transpose)
4348
@eval function Base.$(op)(operator::JacobianOperator)
4449
@set! operator.mode = flip_mode(operator.mode)
50+
(; output_cache, input_cache) = operator
51+
@set! operator.output_cache = input_cache
52+
@set! operator.input_cache = output_cache
4553
return operator
4654
end
4755
end
@@ -53,16 +61,66 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
5361
f = prob.f
5462
iip = SciMLBase.isinplace(prob)
5563
T = promote_type(eltype(u), eltype(fu))
56-
fₚ = SciMLBase.JacobianWrapper{iip}(f, prob.p)
5764

58-
vjp_op, vjp_extras = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
59-
jvp_op, jvp_extras = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
65+
vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
66+
jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
67+
68+
output_cache = iip ? similar(fu, T) : nothing
69+
input_cache = iip ? similar(u, T) : nothing
6070

6171
return JacobianOperator{iip, T}(
62-
JVP(), jvp_op, vjp_op, (length(fu), length(u)), jvp_extras, vjp_extras)
72+
JVP(), jvp_op, vjp_op, (length(fu), length(u)), output_cache, input_cache)
6373
end
6474

65-
prepare_vjp(::Val{true}, args...; kwargs...) = nothing, nothing
75+
function (op::JacobianOperator)(v, u, p)
76+
if op.mode isa VJP
77+
if SciMLBase.isinplace(op)
78+
res = zero(op.output_cache)
79+
op.vjp_op(res, v, u, p)
80+
return res
81+
end
82+
return op.vjp_op(v, u, p)
83+
else
84+
if SciMLBase.isinplace(op)
85+
res = zero(op.output_cache)
86+
op.jvp_op(res, v, u, p)
87+
return res
88+
end
89+
return op.jvp_op(v, u, p)
90+
end
91+
end
92+
93+
function (op::JacobianOperator)(::Number, ::Number, _, __)
94+
error("Inplace Jacobian Operator not possible for scalars.")
95+
end
96+
97+
function (op::JacobianOperator)(Jv, v, u, p)
98+
if op.mode isa VJP
99+
if SciMLBase.isinplace(op)
100+
op.vjp_op(Jv, v, u, p)
101+
return
102+
end
103+
copyto!(Jv, op.vjp_op(v, u, p))
104+
return
105+
else
106+
if SciMLBase.isinplace(op)
107+
op.jvp_op(Jv, v, u, p)
108+
return
109+
end
110+
copyto!(Jv, op.jvp_op(v, u, p))
111+
return
112+
end
113+
end
114+
115+
function VecJacOperator(args...; autodiff = nothing, kwargs...)
116+
return JacobianOperator(args...; kwargs..., skip_jvp = True, vjp_autodiff = autodiff)'
117+
end
118+
119+
function JacVecOperator(args...; autodiff = nothing, kwargs...)
120+
return JacobianOperator(args...; kwargs..., skip_vjp = True, jvp_autodiff = autodiff)
121+
end
122+
123+
prepare_vjp(::Val{true}, args...; kwargs...) = nothing
66124

67125
function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
68126
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
@@ -71,20 +129,19 @@ end
71129

72130
function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
73131
f::AbstractNonlinearFunction, u, fu; autodiff = nothing)
74-
SciMLBase.has_vjp(f) && return f.vjp, nothing
132+
SciMLBase.has_vjp(f) && return f.vjp
75133

76134
if autodiff === nothing && SciMLBase.has_jac(f)
77135
if SciMLBase.isinplace(f)
78-
vjp_extras = (; jac_cache = similar(u, eltype(fu), length(fu), length(u)))
79-
vjp_op = @closure (vJ, v, u, p, extras) -> begin
80-
f.jac(extras.jac_cache, u, p)
81-
mul!(vec(vJ), extras.jac_cache', vec(v))
136+
jac_cache = similar(u, eltype(fu), length(fu), length(u))
137+
return @closure (vJ, v, u, p) -> begin
138+
f.jac(jac_cache, u, p)
139+
mul!(vec(vJ), jac_cache', vec(v))
82140
return
83141
end
84142
return vjp_op, vjp_extras
85143
else
86-
vjp_op = @closure (v, u, p, _) -> reshape(f.jac(u, p)' * vec(v), size(u))
87-
return vjp_op, nothing
144+
return @closure (v, u, p) -> reshape(f.jac(u, p)' * vec(v), size(u))
88145
end
89146
end
90147

@@ -102,21 +159,16 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
102159
fu_cache = copy(fu)
103160
v_fake = copy(fu)
104161
di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake)
105-
vjp_op = @closure (vJ, v, u, p, extras) -> begin
106-
DI.pullback!(
107-
fₚ, extras.fu_cache, reshape(vJ, size(u)), autodiff, u, v, extras.di_extras)
162+
return @closure (vJ, v, u, p) -> begin
163+
DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff, u, v, di_extras)
108164
end
109-
return vjp_op, (; di_extras, fu_cache)
110165
else
111166
di_extras = DI.prepare_pullback(f, autodiff, u, fu)
112-
vjp_op = @closure (v, u, p, extras) -> begin
113-
return DI.pullback(f, autodiff, u, v, extras.di_extras)
114-
end
115-
return vjp_op, (; di_extras)
167+
return @closure (v, u, p) -> DI.pullback(f, autodiff, u, v, di_extras)
116168
end
117169
end
118170

119-
prepare_jvp(skip::Val{true}, args...; kwargs...) = nothing, nothing
171+
prepare_jvp(skip::Val{true}, args...; kwargs...) = nothing
120172

121173
function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
122174
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
@@ -125,20 +177,18 @@ end
125177

126178
function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
127179
f::AbstractNonlinearFunction, u, fu; autodiff = nothing)
128-
SciMLBase.has_vjp(f) && return f.vjp, nothing
180+
SciMLBase.has_vjp(f) && return f.vjp
129181

130182
if autodiff === nothing && SciMLBase.has_jac(f)
131183
if SciMLBase.isinplace(f)
132-
jvp_extras = (; jac_cache = similar(u, eltype(fu), length(fu), length(u)))
133-
jvp_op = @closure (Jv, v, u, p, extras) -> begin
134-
f.jac(extras.jac_cache, u, p)
135-
mul!(vec(Jv), extras.jac_cache, vec(v))
184+
jac_cache = similar(u, eltype(fu), length(fu), length(u))
185+
return @closure (Jv, v, u, p) -> begin
186+
f.jac(jac_cache, u, p)
187+
mul!(vec(Jv), jac_cache, vec(v))
136188
return
137189
end
138-
return jvp_op, jvp_extras
139190
else
140-
jvp_op = @closure (v, u, p, _) -> reshape(f.jac(u, p) * vec(v), size(u))
141-
return jvp_op, nothing
191+
return @closure (v, u, p, _) -> reshape(f.jac(u, p) * vec(v), size(u))
142192
end
143193
end
144194

@@ -155,43 +205,29 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
155205
if SciMLBase.isinplace(f)
156206
fu_cache = copy(fu)
157207
di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
158-
jvp_op = @closure (Jv, v, u, p, extras) -> begin
159-
DI.pushforward!(fₚ, extras.fu_cache, reshape(Jv, size(extras.fu_cache)),
160-
autodiff, u, v, extras.di_extras)
208+
return @closure (Jv, v, u, p) -> begin
209+
DI.pushforward!(fₚ, fu_cache, reshape(Jv, size(fu_cache)), autodiff, u, v,
210+
di_extras)
211+
return
161212
end
162-
return jvp_op, (; di_extras, fu_cache)
163213
else
164-
di_extras = DI.prepare_pushforward(f, autodiff, u, u)
165-
jvp_op = @closure (v, u, p, extras) -> begin
166-
return DI.pushforward(f, autodiff, u, v, extras.di_extras)
167-
end
168-
return jvp_op, (; di_extras)
214+
di_extras = DI.prepare_pushforward(fₚ, autodiff, u, u)
215+
return @closure (v, u, p) -> DI.pushforward(fₚ, autodiff, u, v, di_extras)
169216
end
170217
end
171218

172219
function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,
173220
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
174-
SciMLBase.has_vjp(f) && return f.vjp, nothing
175-
SciMLBase.has_jvp(f) && return f.jvp, nothing
176-
SciMLBase.has_jac(f) && return @closure((v, u, p, _)->f.jac(u, p) * v), nothing
221+
SciMLBase.has_vjp(f) && return f.vjp
222+
SciMLBase.has_jvp(f) && return f.jvp
223+
SciMLBase.has_jac(f) && return @closure((v, u, p)->f.jac(u, p) * v)
177224

178225
@assert autodiff!==nothing "`autodiff` must be provided if `f` doesn't have \
179226
analytic `vjp` or `jvp` or `jac`."
180227
# TODO: Once DI supports const params we can use `p`
181228
fₚ = Base.Fix2(f, prob.p)
182229
di_extras = DI.prepare_derivative(fₚ, autodiff, u)
183-
op = @closure (v, u, p, extras) -> begin
184-
return DI.derivative(fₚ, autodiff, u, extras.di_extras) * v
185-
end
186-
return op, (; di_extras)
187-
end
188-
189-
function VecJacOperator(args...; autodiff = nothing, kwargs...)
190-
return JacobianOperator(args...; kwargs..., skip_jvp = True, vjp_autodiff = autodiff)'
191-
end
192-
193-
function JacVecOperator(args...; autodiff = nothing, kwargs...)
194-
return JacobianOperator(args...; kwargs..., skip_vjp = True, jvp_autodiff = autodiff)
230+
return @closure (v, u, p) -> DI.derivative(fₚ, autodiff, u, di_extras) * v
195231
end
196232

197233
export JacobianOperator, VecJacOperator, JacVecOperator

lib/SciMLJacobianOperators/test/runtests.jl

Whitespace-only changes.

0 commit comments

Comments
 (0)