2 changes: 1 addition & 1 deletion lib/SciMLJacobianOperators/Project.toml
@@ -19,7 +19,7 @@ ADTypes = "1.8.1"
Aqua = "0.8.7"
ConcreteStructs = "0.2.3"
ConstructionBase = "1.5"
-DifferentiationInterface = "0.5"
+DifferentiationInterface = "0.6"
Enzyme = "0.12, 0.13"
EnzymeCore = "0.7, 0.8"
ExplicitImports = "1.9.0"
56 changes: 22 additions & 34 deletions lib/SciMLJacobianOperators/src/SciMLJacobianOperators.jl
@@ -1,9 +1,9 @@
module SciMLJacobianOperators

-using ADTypes: ADTypes, AutoSparse, AutoEnzyme
+using ADTypes: ADTypes, AutoSparse
using ConcreteStructs: @concrete
using ConstructionBase: ConstructionBase
-using DifferentiationInterface: DifferentiationInterface
+using DifferentiationInterface: DifferentiationInterface, Constant
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra
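Note for context: `Constant` is the DifferentiationInterface 0.6 mechanism for passing non-differentiated parameters alongside the differentiated input, which is what lets the closure-based `JacobianWrapper`/`Fix2` workarounds below be dropped. A minimal, hypothetical sketch (the toy function `g`, the values `u0`/`p0`, and the ForwardDiff backend are illustrative, not part of this PR):

```julia
# Differentiate with respect to `u` only; `p` is tagged as a constant context.
using DifferentiationInterface: DifferentiationInterface, Constant
using ADTypes: AutoForwardDiff
import ForwardDiff

const DI = DifferentiationInterface

g(u, p) = sum(abs2, p .* u)          # `p` is a parameter, not a variable
u0 = [1.0, 2.0]
p0 = [3.0, 4.0]

DI.gradient(g, AutoForwardDiff(), u0, Constant(p0))   # ≈ 2 .* p0 .^ 2 .* u0
```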
@@ -112,10 +112,10 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
iip = SciMLBase.isinplace(prob)
T = promote_type(eltype(u), eltype(fu))

-vjp_autodiff = set_function_as_const(get_dense_ad(vjp_autodiff))
+vjp_autodiff = get_dense_ad(vjp_autodiff)
vjp_op = prepare_vjp(skip_vjp, prob, f, u, fu; autodiff = vjp_autodiff)

-jvp_autodiff = set_function_as_const(get_dense_ad(jvp_autodiff))
+jvp_autodiff = get_dense_ad(jvp_autodiff)
jvp_op = prepare_jvp(skip_jvp, prob, f, u, fu; autodiff = jvp_autodiff)

output_cache = fu isa Number ? T(fu) : similar(fu, T)
@@ -295,23 +295,21 @@ function prepare_vjp(::Val{false}, prob::AbstractNonlinearProblem,

@assert autodiff!==nothing "`vjp_autodiff` must be provided if `f` doesn't have \
analytic `vjp` or `jac`."
-# TODO: Once DI supports const params we can use `p`
-fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
if SciMLBase.isinplace(f)
-@assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \
-problems."
+@assert DI.check_inplace(autodiff) "Backend: $(autodiff) doesn't support in-place \
+problems."
fu_cache = copy(fu)
-v_fake = copy(fu)
-di_extras = DI.prepare_pullback(fₚ, fu_cache, autodiff, u, v_fake)
+di_extras = DI.prepare_pullback(f, fu_cache, autodiff, u, (fu,), Constant(prob.p))
return @closure (vJ, v, u, p) -> begin
-DI.pullback!(fₚ, fu_cache, reshape(vJ, size(u)), autodiff,
-u, reshape(v, size(fu_cache)), di_extras)
+DI.pullback!(f, fu_cache, (reshape(vJ, size(u)),), di_extras, autodiff,
+u, (reshape(v, size(fu_cache)),), Constant(p))
return
end
else
-di_extras = DI.prepare_pullback(fₚ, autodiff, u, fu)
+di_extras = DI.prepare_pullback(f, autodiff, u, (fu,), Constant(prob.p))
return @closure (v, u, p) -> begin
-return DI.pullback(fₚ, autodiff, u, reshape(v, size(fu)), di_extras)
+return only(DI.pullback(
+f, di_extras, autodiff, u, (reshape(v, size(fu)),), Constant(p)))
end
end
end
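For readers migrating similar code, here is a minimal out-of-place sketch of the DI 0.6 pullback call shape used above; the toy residual `f_demo`, the values `u0`/`p0`/`v0`, and the Zygote backend are illustrative assumptions, not part of this PR:

```julia
using DifferentiationInterface: DifferentiationInterface, Constant
using ADTypes: AutoZygote
import Zygote

const DI = DifferentiationInterface

f_demo(u, p) = p .* u .^ 2           # toy residual with the same shape as `u`
u0 = [1.0, 2.0]
p0 = [3.0, 4.0]
v0 = ones(2)                         # cotangent seed, playing the role of `v`

backend = AutoZygote()
prep = DI.prepare_pullback(f_demo, backend, u0, (v0,), Constant(p0))
vJ = only(DI.pullback(f_demo, prep, backend, u0, (v0,), Constant(p0)))
# The Jacobian is Diagonal(2 .* p0 .* u0), so vJ ≈ [6.0, 16.0]
```

Seeds and results now travel as tuples in DI 0.6, hence the `(v0,)` wrapping and the `only(...)` unwrap mirrored in the code above.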
@@ -342,23 +340,21 @@ function prepare_jvp(::Val{false}, prob::AbstractNonlinearProblem,

@assert autodiff!==nothing "`jvp_autodiff` must be provided if `f` doesn't have \
analytic `vjp` or `jac`."
-# TODO: Once DI supports const params we can use `p`
-fₚ = SciMLBase.JacobianWrapper{SciMLBase.isinplace(f)}(f, prob.p)
if SciMLBase.isinplace(f)
-@assert DI.check_twoarg(autodiff) "Backend: $(autodiff) doesn't support in-place \
-problems."
+@assert DI.check_inplace(autodiff) "Backend: $(autodiff) doesn't support in-place \
+problems."
fu_cache = copy(fu)
-di_extras = DI.prepare_pushforward(fₚ, fu_cache, autodiff, u, u)
+di_extras = DI.prepare_pushforward(f, fu_cache, autodiff, u, (u,), Constant(prob.p))
return @closure (Jv, v, u, p) -> begin
-DI.pushforward!(
-fₚ, fu_cache, reshape(Jv, size(fu_cache)),
-autodiff, u, reshape(v, size(u)), di_extras)
+DI.pushforward!(f, fu_cache, (reshape(Jv, size(fu_cache)),), di_extras,
+autodiff, u, (reshape(v, size(u)),), Constant(p))
return
end
else
-di_extras = DI.prepare_pushforward(fₚ, autodiff, u, u)
+di_extras = DI.prepare_pushforward(f, autodiff, u, (u,), Constant(prob.p))
return @closure (v, u, p) -> begin
-return DI.pushforward(fₚ, autodiff, u, reshape(v, size(u)), di_extras)
+return only(DI.pushforward(
+f, di_extras, autodiff, u, (reshape(v, size(u)),), Constant(p)))
end
end
end
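Similarly, a minimal in-place sketch of the DI 0.6 pushforward call shape used above; `f_demo!`, `u0`/`p0`/`v0`, the caches, and the ForwardDiff backend are illustrative assumptions, not part of this PR:

```julia
using DifferentiationInterface: DifferentiationInterface, Constant
using ADTypes: AutoForwardDiff
import ForwardDiff

const DI = DifferentiationInterface

function f_demo!(res, u, p)          # toy in-place residual
    res .= p .* u .^ 2
    return nothing
end

u0 = [1.0, 2.0]
p0 = [3.0, 4.0]
v0 = [1.0, 0.0]                      # tangent seed, playing the role of `v`
res = similar(u0)                    # primal output cache (the `fu_cache` above)
Jv = similar(u0)                     # result cache

backend = AutoForwardDiff()
prep = DI.prepare_pushforward(f_demo!, res, backend, u0, (v0,), Constant(p0))
DI.pushforward!(f_demo!, res, (Jv,), prep, backend, u0, (v0,), Constant(p0))
# Jv ≈ [2 * p0[1] * u0[1], 0.0] == [6.0, 0.0]
```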
@@ -371,10 +367,8 @@ function prepare_scalar_op(::Val{false}, prob::AbstractNonlinearProblem,

@assert autodiff!==nothing "`autodiff` must be provided if `f` doesn't have \
analytic `vjp` or `jvp` or `jac`."
-# TODO: Once DI supports const params we can use `p`
-fₚ = Base.Fix2(f, prob.p)
-di_extras = DI.prepare_derivative(fₚ, autodiff, u)
-return @closure (v, u, p) -> DI.derivative(fₚ, autodiff, u, di_extras) * v
+di_extras = DI.prepare_derivative(f, autodiff, u, Constant(prob.p))
+return @closure (v, u, p) -> DI.derivative(f, di_extras, autodiff, u, Constant(p)) * v
end

get_dense_ad(::Nothing) = nothing
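For the scalar path, the same pattern applies with `derivative`; a small sketch under the same assumptions (`f_demo`, `u0`/`p0`/`v0`, and the ForwardDiff backend are illustrative, not part of this PR):

```julia
using DifferentiationInterface: DifferentiationInterface, Constant
using ADTypes: AutoForwardDiff
import ForwardDiff

const DI = DifferentiationInterface

f_demo(u, p) = p * sin(u)            # scalar residual with a scalar parameter
u0, p0, v0 = 0.5, 2.0, 1.0

backend = AutoForwardDiff()
prep = DI.prepare_derivative(f_demo, backend, u0, Constant(p0))
Jv = DI.derivative(f_demo, prep, backend, u0, Constant(p0)) * v0
# Jv == p0 * cos(u0) * v0
```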
@@ -386,12 +380,6 @@ function get_dense_ad(ad::AutoSparse)
return dense_ad
end

-# In our case we know that it is safe to mark the function as const
-set_function_as_const(ad) = ad
-function set_function_as_const(ad::AutoEnzyme{M, Nothing}) where {M}
-return AutoEnzyme(; ad.mode, function_annotation = EnzymeCore.Const)
-end

export JacobianOperator, VecJacOperator, JacVecOperator
export StatefulJacobianOperator
export StatefulJacobianNormalFormOperator
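For reference, the sparse-backend unwrapping that `get_dense_ad` performs can be pictured as below; this is an assumed sketch of the idea using the `ADTypes.dense_ad` accessor, not code from this package. The Enzyme-specific `set_function_as_const` rewrap is no longer needed because the parameters now travel as an explicit `Constant` context rather than as data captured inside the differentiated closure.

```julia
# Assumed illustration: strip the sparsity layer from an `AutoSparse` backend
# and keep the underlying dense backend, which is all JVP/VJP operators need.
using ADTypes: ADTypes, AutoSparse, AutoForwardDiff

sparse_backend = AutoSparse(AutoForwardDiff())
dense_backend = ADTypes.dense_ad(sparse_backend)   # -> AutoForwardDiff()
```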
6 changes: 3 additions & 3 deletions lib/SciMLJacobianOperators/test/core_tests.jl
@@ -7,7 +7,7 @@
AutoEnzyme(),
AutoEnzyme(; mode = Enzyme.Reverse),
AutoZygote(),
-AutoReverseDiff(),
+# AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503
AutoTracker(),
AutoFiniteDiff()
]
@@ -91,7 +91,7 @@ end
reverse_ADs = [
AutoEnzyme(),
AutoEnzyme(; mode = Enzyme.Reverse),
-AutoReverseDiff(),
+# AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503
AutoFiniteDiff()
]

@@ -182,7 +182,7 @@ end
AutoEnzyme(; mode = Enzyme.Reverse),
AutoZygote(),
AutoTracker(),
-AutoReverseDiff(),
+# AutoReverseDiff(), # FIXME: https://github.com/gdalle/DifferentiationInterface.jl/issues/503
AutoFiniteDiff()
]
