Skip to content

Commit 1a8cd78

Browse files
committed
feat: add SciMLJacobianOperators package
1 parent e457a98 commit 1a8cd78

File tree

4 files changed

+252
-0
lines changed

4 files changed

+252
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2323
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2424
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2525
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
26+
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
2627
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2728
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2829
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

lib/SciMLJacobianOperators/LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 SciML
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
name = "SciMLJacobianOperators"
2+
uuid = "19f34311-ddf3-4b8b-af20-060888a46c0e"
3+
authors = ["Avik Pal <[email protected]> and contributors"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
8+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
9+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
10+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
11+
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
12+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
13+
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
14+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
15+
16+
[compat]
17+
ADTypes = "1.8.1"
18+
ConcreteStructs = "0.2.3"
19+
ConstructionBase = "1.5.8"
20+
DifferentiationInterface = "0.5.17"
21+
FastClosures = "0.3.2"
22+
SciMLOperators = "0.3.10"
23+
Setfield = "1.1.1"
24+
julia = "1.10"
25+
26+
[extras]
27+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
28+
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
29+
30+
[targets]
31+
test = ["Test", "TestItemRunner"]
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
module SciMLJacobianOperators
2+
3+
using ADTypes: ADTypes
4+
using ConcreteStructs: @concrete
5+
using ConstructionBase: ConstructionBase
6+
using DifferentiationInterface: DifferentiationInterface
7+
using FastClosures: @closure
8+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, AbstractNonlinearFunction
9+
using SciMLOperators: AbstractSciMLOperator
10+
using Setfield: @set!
11+
12+
const DI = DifferentiationInterface
13+
const True = Val(true)
14+
const False = Val(false)
15+
16+
abstract type AbstractMode end
17+
18+
struct VJP <: AbstractMode end
19+
struct JVP <: AbstractMode end
20+
21+
flip_mode(::VJP) = JVP()
22+
flip_mode(::JVP) = VJP()
23+
24+
@concrete struct JacobianOperator{iip, T <: Real} <: AbstractSciMLOperator{T}
25+
mode <: AbstractMode
26+
27+
jvp_op
28+
vjp_op
29+
30+
size
31+
jvp_extras
32+
vjp_extras
33+
end
34+
35+
function ConstructionBase.constructorof(::Type{<:JacobianOperator{iip, T}}) where {iip, T}
36+
return JacobianOperator{iip, T}
37+
end
38+
39+
Base.size(J::JacobianOperator) = J.size
40+
Base.size(J::JacobianOperator, d::Integer) = J.size[d]
41+
42+
for op in (:adjoint, :transpose)
43+
@eval function Base.$(op)(operator::JacobianOperator)
44+
@set! operator.mode = flip_mode(operator.mode)
45+
return operator
46+
end
47+
end
48+
49+
function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = nothing,
50+
vjp_autodiff = nothing, skip_vjp::Val = False, skip_jvp::Val = False)
51+
@assert !(skip_vjp === True && skip_jvp === True) "Cannot skip both vjp and jvp \
52+
construction."
53+
f = prob.f
54+
iip = SciMLBase.isinplace(prob)
55+
T = promote_type(eltype(u), eltype(fu))
56+
fₚ = SciMLBase.JacobianWrapper{iip}(f, prob.p)
57+
58+
vjp_op, vjp_extras = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)
59+
jvp_op, jvp_extras = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)
60+
61+
return JacobianOperator{iip, T}(
62+
JVP(), jvp_op, vjp_op, (length(fu), length(u)), jvp_extras, vjp_extras)
63+
end
64+
65+
prepare_vjp(::Val{true}, args...; kwargs...) = nothing, nothing
66+
67+
function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
68+
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
69+
return prepare_scalar_op(Val(false), prob, f, u, fu; autodiff)
70+
end
71+
72+
function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,
73+
f::AbstractNonlinearFunction, u, fu; autodiff = nothing)
74+
SciMLBase.has_vjp(f) && return f.vjp, nothing
75+
76+
if autodiff === nothing && SciMLBase.has_jac(f)
77+
if SciMLBase.isinplace(f)
78+
vjp_extras = (; jac_cache = similar(u, eltype(fu), length(fu), length(u)))
79+
vjp_op = @closure (vJ, v, u, p, extras) -> begin
80+
f.jac(extras.jac_cache, u, p)
81+
mul!(vec(vJ), extras.jac_cache', vec(v))
82+
return
83+
end
84+
return vjp_op, vjp_extras
85+
else
86+
vjp_op = @closure (v, u, p, _) -> reshape(f.jac(u, p)' * vec(v), size(u))
87+
return vjp_op, nothing
88+
end
89+
end
90+
91+
@assert autodiff!==nothing "`vjp_autodiff` must be provided if `f` doesn't have \
92+
analytic `vjp` or `jac`."
93+
94+
if ADTypes.mode(autodiff) isa ADTypes.ForwardMode
95+
@warn "AD Backend: $(autodiff) is a Forward Mode backend. Computing VJPs using \
96+
this will be slow!"
97+
end
98+
99+
# TODO: Once DI supports const params we can use `p`
100+
fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
101+
if SciMLBase.isinplace(f)
102+
fu_cache = copy(fu)
103+
v_fake = copy(fu)
104+
di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake)
105+
vjp_op = @closure (vJ, v, u, p, extras) -> begin
106+
DI.pullback!(
107+
fₚ, extras.fu_cache, reshape(vJ, size(u)), autodiff, u, v, extras.di_extras)
108+
end
109+
return vjp_op, (; di_extras, fu_cache)
110+
else
111+
di_extras = DI.prepare_pullback(f, autodiff, u, fu)
112+
vjp_op = @closure (v, u, p, extras) -> begin
113+
return DI.pullback(f, autodiff, u, v, extras.di_extras)
114+
end
115+
return vjp_op, (; di_extras)
116+
end
117+
end
118+
119+
prepare_jvp(skip::Val{true}, args...; kwargs...) = nothing, nothing
120+
121+
function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
122+
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
123+
return prepare_scalar_op(Val(false), prob, f, u, fu; autodiff)
124+
end
125+
126+
function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,
127+
f::AbstractNonlinearFunction, u, fu; autodiff = nothing)
128+
SciMLBase.has_vjp(f) && return f.vjp, nothing
129+
130+
if autodiff === nothing && SciMLBase.has_jac(f)
131+
if SciMLBase.isinplace(f)
132+
jvp_extras = (; jac_cache = similar(u, eltype(fu), length(fu), length(u)))
133+
jvp_op = @closure (Jv, v, u, p, extras) -> begin
134+
f.jac(extras.jac_cache, u, p)
135+
mul!(vec(Jv), extras.jac_cache, vec(v))
136+
return
137+
end
138+
return jvp_op, jvp_extras
139+
else
140+
jvp_op = @closure (v, u, p, _) -> reshape(f.jac(u, p) * vec(v), size(u))
141+
return jvp_op, nothing
142+
end
143+
end
144+
145+
@assert autodiff!==nothing "`jvp_autodiff` must be provided if `f` doesn't have \
146+
analytic `vjp` or `jac`."
147+
148+
if ADTypes.mode(autodiff) isa ADTypes.ReverseMode
149+
@warn "AD Backend: $(autodiff) is a Reverse Mode backend. Computing JVPs using \
150+
this will be slow!"
151+
end
152+
153+
# TODO: Once DI supports const params we can use `p`
154+
fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
155+
if SciMLBase.isinplace(f)
156+
fu_cache = copy(fu)
157+
di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
158+
jvp_op = @closure (Jv, v, u, p, extras) -> begin
159+
DI.pushforward!(fₚ, extras.fu_cache, reshape(Jv, size(extras.fu_cache)),
160+
autodiff, u, v, extras.di_extras)
161+
end
162+
return jvp_op, (; di_extras, fu_cache)
163+
else
164+
di_extras = DI.prepare_pushforward(f, autodiff, u, u)
165+
jvp_op = @closure (v, u, p, extras) -> begin
166+
return DI.pushforward(f, autodiff, u, v, extras.di_extras)
167+
end
168+
return jvp_op, (; di_extras)
169+
end
170+
end
171+
172+
function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,
173+
f::AbstractNonlinearFunction, u::Number, fu::Number; autodiff = nothing)
174+
SciMLBase.has_vjp(f) && return f.vjp, nothing
175+
SciMLBase.has_jvp(f) && return f.jvp, nothing
176+
SciMLBase.has_jac(f) && return @closure((v, u, p, _)->f.jac(u, p) * v), nothing
177+
178+
@assert autodiff!==nothing "`autodiff` must be provided if `f` doesn't have \
179+
analytic `vjp` or `jvp` or `jac`."
180+
# TODO: Once DI supports const params we can use `p`
181+
fₚ = Base.Fix2(f, prob.p)
182+
di_extras = DI.prepare_derivative(fₚ, autodiff, u)
183+
op = @closure (v, u, p, extras) -> begin
184+
return DI.derivative(fₚ, autodiff, u, extras.di_extras) * v
185+
end
186+
return op, (; di_extras)
187+
end
188+
189+
function VecJacOperator(args...; autodiff = nothing, kwargs...)
190+
return JacobianOperator(args...; kwargs..., skip_jvp = True, vjp_autodiff = autodiff)'
191+
end
192+
193+
function JacVecOperator(args...; autodiff = nothing, kwargs...)
194+
return JacobianOperator(args...; kwargs..., skip_vjp = True, jvp_autodiff = autodiff)
195+
end
196+
197+
export JacobianOperator, VecJacOperator, JacVecOperator
198+
199+
end

0 commit comments

Comments
 (0)