diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 61a75e8fa..c37435907 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -93,6 +93,8 @@ hvp! prepare_hessian hessian hessian! +value_gradient_and_hessian +value_gradient_and_hessian! ``` ## Utilities diff --git a/DifferentiationInterface/docs/src/operators.md b/DifferentiationInterface/docs/src/operators.md index 65d2703e4..dde8132c9 100644 --- a/DifferentiationInterface/docs/src/operators.md +++ b/DifferentiationInterface/docs/src/operators.md @@ -45,16 +45,16 @@ These operators are computed using the input `x` and a "seed" `v`, which lives e Several variants of each operator are defined. -| out-of-place | in-place | out-of-place + primal | in-place + primal | -| :-------------------------- | :--------------------------- | :----------------------------------------------- | :----------------------------------------------- | -| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) | +| out-of-place | in-place | out-of-place + primal | in-place + primal | +| :-------------------------- | :--------------------------- | :----------------------------------------------- | :------------------------------------------------ | +| [`derivative`](@ref) | [`derivative!`](@ref) | [`value_and_derivative`](@ref) | [`value_and_derivative!`](@ref) | | [`second_derivative`](@ref) | [`second_derivative!`](@ref) | [`value_derivative_and_second_derivative`](@ref) | [`value_derivative_and_second_derivative!`](@ref) | -| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) | -| [`hessian`](@ref) | [`hessian!`](@ref) | NA | NA | -| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) | -| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) | -| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) | -| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA | +| [`gradient`](@ref) | [`gradient!`](@ref) | [`value_and_gradient`](@ref) | [`value_and_gradient!`](@ref) | +| [`hessian`](@ref) | [`hessian!`](@ref) | [`value_gradient_and_hessian`](@ref) | [`value_gradient_and_hessian!`](@ref) NA | +| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) | +| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) | +| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) | +| [`hvp`](@ref) | [`hvp!`](@ref) | NA | NA | ## Mutation and signatures diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 499592b03..404547242 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -10,7 +10,8 @@ using DifferentiationInterface: JacobianExtras, PullbackExtras, PushforwardExtras, - SecondDerivativeExtras + SecondDerivativeExtras, + maybe_dense_ad using FastDifferentiation: derivative, hessian, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index 084c42dd4..74b3ca8fb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -1,9 +1,9 @@ ## Pushforward -struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras +struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E1!} <: PushforwardExtras y_prototype::Y jvp_exe::E1 - jvp_exe!::E2 + jvp_exe!::E1! end function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x, dx) @@ -70,9 +70,9 @@ end ## Pullback -struct FastDifferentiationOneArgPullbackExtras{E1,E2} <: PullbackExtras +struct FastDifferentiationOneArgPullbackExtras{E1,E1!} <: PullbackExtras vjp_exe::E1 - vjp_exe!::E2 + vjp_exe!::E1! end function DI.prepare_pullback(f, ::AutoFastDifferentiation, x, dy) @@ -133,10 +133,10 @@ end ## Derivative -struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2} <: DerivativeExtras +struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E1!} <: DerivativeExtras y_prototype::Y der_exe::E1 - der_exe!::E2 + der_exe!::E1! end function DI.prepare_derivative(f, ::AutoFastDifferentiation, x) @@ -190,13 +190,12 @@ end ## Gradient -struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras +struct FastDifferentiationOneArgGradientExtras{E1,E1!} <: GradientExtras jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_gradient(f, backend::AutoFastDifferentiation, x) - y_prototype = f(x) x_var = make_variables(:x, size(x)...) y_var = f(x_var) @@ -241,10 +240,10 @@ end ## Jacobian -struct FastDifferentiationOneArgJacobianExtras{Y,E1,E2} <: JacobianExtras +struct FastDifferentiationOneArgJacobianExtras{Y,E1,E1!} <: JacobianExtras y_prototype::Y jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_jacobian( @@ -307,16 +306,15 @@ end ## Second derivative -struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E1!,E2,E2!} <: +struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,D,E2,E2!} <: SecondDerivativeExtras y_prototype::Y - der_exe::E1 - der_exe!::E1! + derivative_extras::D der2_exe::E2 der2_exe!::E2! end -function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x) +function DI.prepare_second_derivative(f, backend::AutoFastDifferentiation, x) y_prototype = f(x) x_var = only(make_variables(:x)) y_var = f(x_var) @@ -324,17 +322,13 @@ function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x) x_vec_var = monovec(x_var) y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var) - der_vec_var = derivative(y_vec_var, x_var) der2_vec_var = derivative(y_vec_var, x_var, x_var) - - der_exe = make_function(der_vec_var, x_vec_var; in_place=false) - der_exe! = make_function(der_vec_var, x_vec_var; in_place=true) - der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false) der2_exe! = make_function(der2_vec_var, x_vec_var; in_place=true) + derivative_extras = DI.prepare_derivative(f, backend, x) return FastDifferentiationAllocatingSecondDerivativeExtras( - y_prototype, der_exe, der_exe!, der2_exe, der2_exe! + y_prototype, derivative_extras, der2_exe, der2_exe! ) end @@ -364,20 +358,13 @@ end function DI.value_derivative_and_second_derivative( f, - ::AutoFastDifferentiation, + backend::AutoFastDifferentiation, x, extras::FastDifferentiationAllocatingSecondDerivativeExtras, ) - y = f(x) - if extras.y_prototype isa Number - der = only(extras.der_exe(monovec(x))) - der2 = only(extras.der2_exe(monovec(x))) - return y, der, der2 - else - der = reshape(extras.der_exe(monovec(x)), size(extras.y_prototype)) - der2 = reshape(extras.der2_exe(monovec(x)), size(extras.y_prototype)) - return y, der, der2 - end + y, der = DI.value_and_derivative(f, backend, x, extras.derivative_extras) + der2 = DI.second_derivative(f, backend, x, extras) + return y, der, der2 end function DI.value_derivative_and_second_derivative!( @@ -388,17 +375,16 @@ function DI.value_derivative_and_second_derivative!( x, extras::FastDifferentiationAllocatingSecondDerivativeExtras, ) - y = f(x) - extras.der_exe!(vec(der), monovec(x)) - extras.der2_exe!(vec(der2), monovec(x)) + y, _ = DI.value_and_derivative!(f, der, backend, x, extras.derivative_extras) + DI.second_derivative!(f, der2, backend, x, extras) return y, der, der2 end ## HVP -struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras - hvp_exe::E1 - hvp_exe!::E2 +struct FastDifferentiationHVPExtras{E2,E2!} <: HVPExtras + hvp_exe::E2 + hvp_exe!::E2! end function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, v) @@ -428,24 +414,30 @@ end ## Hessian -struct FastDifferentiationHessianExtras{E1,E2} <: HessianExtras - hess_exe::E1 - hess_exe!::E2 +struct FastDifferentiationHessianExtras{G,E2,E2!} <: HessianExtras + gradient_extras::G + hess_exe::E2 + hess_exe!::E2! end function DI.prepare_hessian( f, backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, x ) - x_vec_var = make_variables(:x, size(x)...) - y_vec_var = f(x_vec_var) + x_var = make_variables(:x, size(x)...) + y_var = f(x_var) + + x_vec_var = vec(x_var) + hess_var = if backend isa AutoSparse - sparse_hessian(y_vec_var, vec(x_vec_var)) + sparse_hessian(y_var, x_vec_var) else - hessian(y_vec_var, vec(x_vec_var)) + hessian(y_var, x_vec_var) end - hess_exe = make_function(hess_var, vec(x_vec_var); in_place=false) - hess_exe! = make_function(hess_var, vec(x_vec_var); in_place=true) - return FastDifferentiationHessianExtras(hess_exe, hess_exe!) + hess_exe = make_function(hess_var, x_vec_var; in_place=false) + hess_exe! = make_function(hess_var, x_vec_var; in_place=true) + + gradient_extras = DI.prepare_gradient(f, maybe_dense_ad(backend), x) + return FastDifferentiationHessianExtras(gradient_extras, hess_exe, hess_exe!) end function DI.hessian( @@ -467,3 +459,29 @@ function DI.hessian!( extras.hess_exe!(hess, vec(x)) return hess end + +function DI.value_gradient_and_hessian( + f, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + x, + extras::FastDifferentiationHessianExtras, +) + y, grad = DI.value_and_gradient(f, maybe_dense_ad(backend), x, extras.gradient_extras) + hess = DI.hessian(f, backend, x, extras) + return y, grad, hess +end + +function DI.value_gradient_and_hessian!( + f, + grad, + hess, + backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}}, + x, + extras::FastDifferentiationHessianExtras, +) + y, _ = DI.value_and_gradient!( + f, grad, maybe_dense_ad(backend), x, extras.gradient_extras + ) + DI.hessian!(f, hess, backend, x, extras) + return y, grad, hess +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index b3141f506..5d7059de8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -1,8 +1,8 @@ ## Pushforward -struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras +struct FastDifferentiationTwoArgPushforwardExtras{E1,E1!} <: PushforwardExtras jvp_exe::E1 - jvp_exe!::E2 + jvp_exe!::E1! end function DI.prepare_pushforward(f!, y, ::AutoFastDifferentiation, x, dx) @@ -80,9 +80,9 @@ end ## Pullback -struct FastDifferentiationTwoArgPullbackExtras{E1,E2} <: PullbackExtras +struct FastDifferentiationTwoArgPullbackExtras{E1,E1!} <: PullbackExtras vjp_exe::E1 - vjp_exe!::E2 + vjp_exe!::E1! end function DI.prepare_pullback(f!, y, ::AutoFastDifferentiation, x, dy) @@ -156,9 +156,9 @@ end ## Derivative -struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras +struct FastDifferentiationTwoArgDerivativeExtras{E1,E1!} <: DerivativeExtras der_exe::E1 - der_exe!::E2 + der_exe!::E1! end function DI.prepare_derivative(f!, y, ::AutoFastDifferentiation, x) @@ -216,9 +216,9 @@ end ## Jacobian -struct FastDifferentiationTwoArgJacobianExtras{E1,E2} <: JacobianExtras +struct FastDifferentiationTwoArgJacobianExtras{E1,E1!} <: JacobianExtras jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index 8f30c94be..96a9e2f00 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -153,21 +153,39 @@ end ## Hessian -struct FiniteDiffHessianExtras{C} <: HessianExtras - cache::C +struct FiniteDiffHessianExtras{C1,C2} <: HessianExtras + gradient_cache::C1 + hessian_cache::C2 end function DI.prepare_hessian(f, backend::AutoFiniteDiff, x) - cache = HessianCache(x, fdhtype(backend)) - return FiniteDiffHessianExtras(cache) + y = f(x) + df = zero(y) .* x + gradient_cache = GradientCache(df, x, fdtype(backend)) + hessian_cache = HessianCache(x, fdhtype(backend)) + return FiniteDiffHessianExtras(gradient_cache, hessian_cache) end -# cache cannot be reused because of https://github.com/JuliaDiff/FiniteDiff.jl/issues/185 - function DI.hessian(f, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras) - return finite_difference_hessian(f, x, extras.cache) + return finite_difference_hessian(f, x, extras.hessian_cache) end function DI.hessian!(f, hess, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras) - return finite_difference_hessian!(hess, f, x, extras.cache) + return finite_difference_hessian!(hess, f, x, extras.hessian_cache) +end + +function DI.value_gradient_and_hessian( + f, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras +) + grad = finite_difference_gradient(f, x, extras.gradient_cache) + hess = finite_difference_hessian(f, x, extras.hessian_cache) + return f(x), grad, hess +end + +function DI.value_gradient_and_hessian!( + f, grad, hess, backend::AutoFiniteDiff, x, extras::FiniteDiffHessianExtras +) + finite_difference_gradient!(grad, f, x, extras.gradient_cache) + finite_difference_hessian!(hess, f, x, extras.hessian_cache) + return f(x), grad, hess end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 6326cbc4f..ad4aa0c61 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -137,20 +137,49 @@ end ## Hessian -struct ForwardDiffHessianExtras{C} <: HessianExtras - config::C +struct ForwardDiffHessianExtras{C1,C2} <: HessianExtras + array_config::C1 + result_config::C2 end function DI.prepare_hessian(f, backend::AutoForwardDiff, x) - return ForwardDiffHessianExtras(HessianConfig(f, x, choose_chunk(backend, x))) + example_result = MutableDiffResult( + one(eltype(x)), (similar(x), similar(x, length(x), length(x))) + ) + chunk = choose_chunk(backend, x) + array_config = HessianConfig(f, x, chunk) + result_config = HessianConfig(f, example_result, x, chunk) + return ForwardDiffHessianExtras(array_config, result_config) end function DI.hessian!( f::F, hess, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras ) where {F} - return hessian!(hess, f, x, extras.config) + return hessian!(hess, f, x, extras.array_config) end function DI.hessian(f::F, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras) where {F} - return hessian(f, x, extras.config) + return hessian(f, x, extras.array_config) +end + +function DI.value_gradient_and_hessian!( + f::F, grad, hess, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras +) where {F} + result = MutableDiffResult(one(eltype(x)), (grad, hess)) + result = hessian!(result, f, x, extras.result_config) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) +end + +function DI.value_gradient_and_hessian( + f::F, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras +) where {F} + result = MutableDiffResult( + one(eltype(x)), (similar(x), similar(x, length(x), length(x))) + ) + result = hessian!(result, f, x, extras.result_config) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 672bdb963..99cfa2f2a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -42,9 +42,9 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f, dy, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras ) - return DI.value_and_derivative!(f, dy, single_threaded(backend), x, extras) + return DI.value_and_derivative!(f, der, single_threaded(backend), x, extras) end function DI.derivative(f, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras) @@ -52,9 +52,9 @@ function DI.derivative(f, backend::AutoPolyesterForwardDiff, x, extras::Derivati end function DI.derivative!( - f, dy, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras + f, der, backend::AutoPolyesterForwardDiff, x, extras::DerivativeExtras ) - return DI.derivative!(f, dy, single_threaded(backend), x, extras) + return DI.derivative!(f, der, single_threaded(backend), x, extras) end ## Gradient @@ -149,6 +149,20 @@ function DI.hessian(f, backend::AutoPolyesterForwardDiff, x, extras::HessianExtr return DI.hessian(f, single_threaded(backend), x, extras) end -function DI.hessian!(f, dy, backend::AutoPolyesterForwardDiff, x, extras::HessianExtras) - return DI.hessian!(f, dy, single_threaded(backend), x, extras) +function DI.hessian!(f, hess, backend::AutoPolyesterForwardDiff, x, extras::HessianExtras) + return DI.hessian!(f, hess, single_threaded(backend), x, extras) +end + +function DI.value_gradient_and_hessian( + f, backend::AutoPolyesterForwardDiff, x, extras::HessianExtras +) + return DI.value_gradient_and_hessian(f, single_threaded(backend), x, extras) +end + +function DI.value_gradient_and_hessian!( + f, grad, hess, backend::AutoPolyesterForwardDiff, x, extras::HessianExtras +) + return DI.value_gradient_and_hessian!( + f, grad, hess, single_threaded(backend), x, extras + ) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index c940db730..89d091b51 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -163,3 +163,30 @@ function DI.hessian( ) return hessian!(extras.tape, x) end + +function DI.value_gradient_and_hessian!( + _f, + grad, + hess::AbstractMatrix, + ::AutoReverseDiff, + x::AbstractArray, + extras::ReverseDiffHessianExtras, +) + result = MutableDiffResult(one(eltype(x)), (grad, hess)) + result = hessian!(result, extras.tape, x) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) +end + +function DI.value_gradient_and_hessian( + _f, ::AutoReverseDiff, x::AbstractArray, extras::ReverseDiffHessianExtras +) + result = MutableDiffResult( + one(eltype(x)), (similar(x), similar(x, length(x), length(x))) + ) + result = hessian!(result, extras.tape, x) + return ( + DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result) + ) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 0074c6d79..b99cd1d24 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -10,7 +10,8 @@ using DifferentiationInterface: JacobianExtras, PullbackExtras, PushforwardExtras, - SecondDerivativeExtras + SecondDerivativeExtras, + maybe_dense_ad using FillArrays: Fill using LinearAlgebra: dot using Symbolics: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index da809b60f..a816bfddc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -1,8 +1,8 @@ ## Pushforward -struct SymbolicsOneArgPushforwardExtras{E1,E2} <: PushforwardExtras +struct SymbolicsOneArgPushforwardExtras{E1,E1!} <: PushforwardExtras pf_exe::E1 - pf_exe!::E2 + pf_exe!::E1! end function DI.prepare_pushforward(f, ::AutoSymbolics, x, dx) @@ -57,9 +57,9 @@ end ## Derivative -struct SymbolicsOneArgDerivativeExtras{E1,E2} <: DerivativeExtras +struct SymbolicsOneArgDerivativeExtras{E1,E1!} <: DerivativeExtras der_exe::E1 - der_exe!::E2 + der_exe!::E1! end function DI.prepare_derivative(f, ::AutoSymbolics, x) @@ -98,9 +98,9 @@ end ## Gradient -struct SymbolicsOneArgGradientExtras{E1,E2} <: GradientExtras +struct SymbolicsOneArgGradientExtras{E1,E1!} <: GradientExtras grad_exe::E1 - grad_exe!::E2 + grad_exe!::E1! end function DI.prepare_gradient(f, ::AutoSymbolics, x) @@ -136,9 +136,9 @@ end ## Jacobian -struct SymbolicsOneArgJacobianExtras{E1,E2} <: JacobianExtras +struct SymbolicsOneArgJacobianExtras{E1,E1!} <: JacobianExtras jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_jacobian( @@ -197,9 +197,10 @@ end ## Hessian -struct SymbolicsOneArgHessianExtras{E1,E2} <: HessianExtras - hess_exe::E1 - hess_exe!::E2 +struct SymbolicsOneArgHessianExtras{G,E2,E2!} <: HessianExtras + gradient_extras::G + hess_exe::E2 + hess_exe!::E2! end function DI.prepare_hessian(f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, x) @@ -213,7 +214,9 @@ function DI.prepare_hessian(f, backend::Union{AutoSymbolics,AutoSparse{<:AutoSym res = build_function(hess_var, vec(x_var); expression=Val(false)) (hess_exe, hess_exe!) = res - return SymbolicsOneArgHessianExtras(hess_exe, hess_exe!) + + gradient_extras = DI.prepare_gradient(f, maybe_dense_ad(backend), x) + return SymbolicsOneArgHessianExtras(gradient_extras, hess_exe, hess_exe!) end function DI.hessian( @@ -235,3 +238,29 @@ function DI.hessian!( extras.hess_exe!(hess, vec(x)) return hess end + +function DI.value_gradient_and_hessian( + f, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + x, + extras::SymbolicsOneArgHessianExtras, +) + y, grad = DI.value_and_gradient(f, maybe_dense_ad(backend), x, extras.gradient_extras) + hess = DI.hessian(f, backend, x, extras) + return y, grad, hess +end + +function DI.value_gradient_and_hessian!( + f, + grad, + hess, + backend::Union{AutoSymbolics,AutoSparse{<:AutoSymbolics}}, + x, + extras::SymbolicsOneArgHessianExtras, +) + y, _ = DI.value_and_gradient!( + f, grad, maybe_dense_ad(backend), x, extras.gradient_extras + ) + DI.hessian!(f, hess, backend, x, extras) + return y, grad, hess +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 14ee93f96..a6e37f6df 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -1,8 +1,8 @@ ## Pushforward -struct SymbolicsTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras +struct SymbolicsTwoArgPushforwardExtras{E1,E1!} <: PushforwardExtras pushforward_exe::E1 - pushforward_exe!::E2 + pushforward_exe!::E1! end function DI.prepare_pushforward(f!, y, ::AutoSymbolics, x, dx) @@ -61,9 +61,9 @@ end ## Derivative -struct SymbolicsTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras +struct SymbolicsTwoArgDerivativeExtras{E1,E1!} <: DerivativeExtras der_exe::E1 - der_exe!::E2 + der_exe!::E1! end function DI.prepare_derivative(f!, y, ::AutoSymbolics, x) @@ -106,9 +106,9 @@ end ## Jacobian -struct SymbolicsTwoArgJacobianExtras{E1,E2} <: JacobianExtras +struct SymbolicsTwoArgJacobianExtras{E1,E1!} <: JacobianExtras jac_exe::E1 - jac_exe!::E2 + jac_exe!::E1! end function DI.prepare_jacobian( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 04d375850..2ce4ba6cb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -131,4 +131,18 @@ function DI.hessian!(f, hess, backend::AutoZygote, x, extras::NoHessianExtras) return copyto!(hess, DI.hessian(f, backend, x, extras)) end +function DI.value_gradient_and_hessian(f, backend::AutoZygote, x, extras::NoHessianExtras) + y, grad = DI.value_and_gradient(f, backend, x, NoGradientExtras()) + hess = DI.hessian(f, backend, x, extras) + return y, grad, hess +end + +function DI.value_gradient_and_hessian!( + f, grad, hess, backend::AutoZygote, x, extras::NoHessianExtras +) + y, _ = DI.value_and_gradient!(f, grad, backend, x, NoGradientExtras()) + DI.hessian!(f, hess, backend, x, extras) + return y, grad, hess +end + end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 73b04c890..504171bbe 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -54,6 +54,7 @@ include("utils/printing.jl") include("utils/chunk.jl") include("utils/check.jl") include("utils/exceptions.jl") +include("utils/maybe.jl") include("first_order/pushforward.jl") include("first_order/pullback.jl") @@ -99,6 +100,7 @@ export second_derivative!, second_derivative export value_derivative_and_second_derivative, value_derivative_and_second_derivative! export hvp!, hvp export hessian!, hessian +export value_gradient_and_hessian, value_gradient_and_hessian! export prepare_pushforward, prepare_pushforward_same_point export prepare_pullback, prepare_pullback_same_point diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 7d083297c..267492eb1 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -24,6 +24,20 @@ Compute the Hessian matrix of the function `f` at point `x`, overwriting `hess`. """ function hessian! end +""" + value_gradient_and_hessian(f, backend, x, [extras]) -> (y, grad, hess) + +Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`. +""" +function value_gradient_and_hessian end + +""" + value_gradient_and_hessian!(f, grad, hess, backend, x, [extras]) -> (y, grad, hess) + +Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`, overwriting `grad` and `hess`. +""" +function value_gradient_and_hessian! end + ## Preparation """ @@ -35,30 +49,22 @@ abstract type HessianExtras <: Extras end struct NoHessianExtras <: HessianExtras end -struct HVPHessianExtras{E<:HVPExtras} <: HessianExtras - hvp_extras::E +struct HVPGradientHessianExtras{E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras + hvp_extras::E2 + gradient_extras::E1 end function prepare_hessian(f::F, backend::AbstractADType, x) where {F} - return prepare_hessian(f, SecondOrder(backend, backend), x) -end - -function prepare_hessian(f::F, backend::SecondOrder, x) where {F} v = basis(backend, x, first(CartesianIndices(x))) hvp_extras = prepare_hvp(f, backend, x, v) - return HVPHessianExtras(hvp_extras) + gradient_extras = prepare_gradient(f, maybe_inner(backend), x) + return HVPGradientHessianExtras(hvp_extras, gradient_extras) end ## One argument function hessian( f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) -) where {F} - return hessian(f, SecondOrder(backend, backend), x, extras) -end - -function hessian( - f::F, backend::SecondOrder, x, extras::HessianExtras=prepare_hessian(f, backend, x) ) where {F} hvp_extras_same = prepare_hvp_same_point( f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras @@ -76,16 +82,6 @@ function hessian!( backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x), -) where {F} - return hessian!(f, hess, SecondOrder(backend, backend), x, extras) -end - -function hessian!( - f::F, - hess, - backend::SecondOrder, - x, - extras::HessianExtras=prepare_hessian(f, backend, x), ) where {F} hvp_extras_same = prepare_hvp_same_point( f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras @@ -96,3 +92,24 @@ function hessian!( end return hess end + +function value_gradient_and_hessian( + f::F, backend::AbstractADType, x, extras::HessianExtras=prepare_hessian(f, backend, x) +) where {F} + y, grad = value_and_gradient(f, maybe_inner(backend), x, extras.gradient_extras) + hess = hessian(f, backend, x, extras) + return y, grad, hess +end + +function value_gradient_and_hessian!( + f::F, + grad, + hess, + backend::AbstractADType, + x, + extras::HessianExtras=prepare_hessian(f, backend, x), +) where {F} + y, _ = value_and_gradient!(f, grad, maybe_inner(backend), x, extras.gradient_extras) + hessian!(f, hess, backend, x, extras) + return y, grad, hess +end diff --git a/DifferentiationInterface/src/sparse/hessian.jl b/DifferentiationInterface/src/sparse/hessian.jl index 222f773cc..074c6fdf0 100644 --- a/DifferentiationInterface/src/sparse/hessian.jl +++ b/DifferentiationInterface/src/sparse/hessian.jl @@ -4,14 +4,16 @@ Base.@kwdef struct SparseHessianExtras{ K<:AbstractVector{<:Integer}, D<:AbstractVector, P<:AbstractVector, - E<:Extras, + E2<:HVPExtras, + E1<:GradientExtras, } <: HessianExtras sparsity::S compressed::C colors::K seeds::D products::P - hvp_extras::E + hvp_extras::E2 + gradient_extras::E1 end ## Hessian, one argument @@ -32,7 +34,10 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} similar(x) end compressed = stack(vec, products; dims=2) - return SparseHessianExtras(; sparsity, compressed, colors, seeds, products, hvp_extras) + gradient_extras = prepare_gradient(f, maybe_inner(dense_backend), x) + return SparseHessianExtras(; + sparsity, compressed, colors, seeds, products, hvp_extras, gradient_extras + ) end function hessian!(f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras) where {F} @@ -56,3 +61,23 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras) wher end return decompress_symmetric(sparsity, compressed, colors) end + +function value_gradient_and_hessian!( + f::F, grad, hess, backend::AutoSparse, x, extras::SparseHessianExtras +) where {F} + y, _ = value_and_gradient!( + f, grad, maybe_inner(dense_ad(backend)), x, extras.gradient_extras + ) + hessian!(f, hess, backend, x, extras) + return y, grad, hess +end + +function value_gradient_and_hessian( + f::F, backend::AutoSparse, x, extras::SparseHessianExtras +) where {F} + y, grad = value_and_gradient( + f, maybe_inner(dense_ad(backend)), x, extras.gradient_extras + ) + hess = hessian(f, backend, x, extras) + return y, grad, hess +end diff --git a/DifferentiationInterface/src/utils/maybe.jl b/DifferentiationInterface/src/utils/maybe.jl new file mode 100644 index 000000000..6b22961f9 --- /dev/null +++ b/DifferentiationInterface/src/utils/maybe.jl @@ -0,0 +1,7 @@ +maybe_inner(backend::SecondOrder) = inner(backend) +maybe_outer(backend::SecondOrder) = outer(backend) +maybe_inner(backend::AbstractADType) = backend +maybe_outer(backend::AbstractADType) = backend + +maybe_dense_ad(backend::AutoSparse) = dense_ad(backend) +maybe_dense_ad(backend::AbstractADType) = backend diff --git a/DifferentiationInterface/test/runtests.jl b/DifferentiationInterface/test/runtests.jl index 89211b9d9..a4910e175 100644 --- a/DifferentiationInterface/test/runtests.jl +++ b/DifferentiationInterface/test/runtests.jl @@ -33,7 +33,7 @@ ALL_BACKENDS = [ @testset verbose = true "DifferentiationInterface.jl" begin if GROUP == "Formalities" || GROUP == "All" @testset "Formalities/$file" for file in readdir(joinpath(@__DIR__, "Formalities")) - @info "Testing Formalities/$file)" + @info "Testing Formalities/$file" include(joinpath(@__DIR__, "Formalities", file)) end end diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index f722f31e9..a8686ddc0 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -24,6 +24,8 @@ using DifferentiationInterface using DifferentiationInterface: backend_str, inner, + maybe_inner, + maybe_dense_ad, mode, outer, twoarg_support, diff --git a/DifferentiationInterfaceTest/src/tests/benchmark.jl b/DifferentiationInterfaceTest/src/tests/benchmark.jl index 4888ef5c5..61994231b 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark.jl @@ -975,27 +975,31 @@ function run_benchmark!( logging::Bool, ) @compat (; f, x, y) = deepcopy(scen) - @compat (; bench0, bench1, calls0, calls1) = try + @compat (; bench0, bench1, bench2, calls0, calls1, calls2) = try # benchmark extras = prepare_hessian(f, ba, x) bench0 = @be prepare_hessian(f, ba, x) samples = 1 evals = 1 bench1 = @be deepcopy(extras) hessian(f, ba, x, _) + bench2 = @be deepcopy(extras) value_gradient_and_hessian(f, ba, x, _) # count cc = CallCounter(f) extras = prepare_hessian(cc, ba, x) calls0 = reset_count!(cc) hessian(cc, ba, x, extras) calls1 = reset_count!(cc) - (; bench0, bench1, calls0, calls1) + value_gradient_and_hessian(cc, ba, x, extras) + calls2 = reset_count!(cc) + (; bench0, bench1, bench2, calls0, calls1, calls2) catch e logging && @warn "Error during benchmarking" ba scen e - bench0, bench1 = failed_benchs(2) - calls0, calls1 = -1, -1 - (; bench0, bench1, calls0, calls1) + bench0, bench1, bench2 = failed_benchs(3) + calls0, calls1, calls2 = -1, -1, -1 + (; bench0, bench1, bench2, calls0, calls1, calls2) end # record record!(data, ba, scen, :prepare_hessian, bench0, calls0) record!(data, ba, scen, :hessian, bench1, calls1) + record!(data, ba, scen, :value_gradient_and_hessian, bench2, calls2) return nothing end @@ -1006,7 +1010,7 @@ function run_benchmark!( logging::Bool, ) @compat (; f, x, y) = deepcopy(scen) - @compat (; bench0, bench1, calls0, calls1) = try + @compat (; bench0, bench1, bench2, calls0, calls1, calls2) = try hess_template = Matrix{typeof(y)}(undef, length(x), length(x)) # benchmark extras = prepare_hessian(f, ba, x) @@ -1014,21 +1018,29 @@ function run_benchmark!( bench1 = @be (hess=mysimilar(hess_template), ext=deepcopy(extras)) hessian!( f, _.hess, ba, x, _.ext ) evals = 1 + bench2 = @be ( + grad=mysimilar(x), hess=mysimilar(hess_template), ext=deepcopy(extras) + ) value_gradient_and_hessian!(f, _.grad, _.hess, ba, x, _.ext) evals = 1 # count cc = CallCounter(f) extras = prepare_hessian(cc, ba, x) calls0 = reset_count!(cc) hessian!(cc, mysimilar(hess_template), ba, x, extras) calls1 = reset_count!(cc) - (; bench0, bench1, calls0, calls1) + value_gradient_and_hessian!( + cc, mysimilar(x), mysimilar(hess_template), ba, x, extras + ) + calls2 = reset_count!(cc) + (; bench0, bench1, bench2, calls0, calls1, calls2) catch e logging && @warn "Error during benchmarking" ba scen e - bench0, bench1 = failed_benchs(2) - calls0, calls1 = -1, -1 - (; bench0, bench1, calls0, calls1) + bench0, bench1, bench2 = failed_benchs(3) + calls0, calls1, calls2 = -1, -1, -1 + (; bench0, bench1, bench2, calls0, calls1, calls2) end # record record!(data, ba, scen, :prepare_hessian, bench0, calls0) record!(data, ba, scen, :hessian!, bench1, calls1) + record!(data, ba, scen, :value_gradient_and_hessian!, bench2, calls2) return nothing end diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index b443a79e2..e83ebadc5 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -794,10 +794,8 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_second_derivative(f, ba, mycopy_random(x)) - der1_true = if ref_backend isa SecondOrder - derivative(f, inner(ref_backend), x) - elseif ref_backend isa AbstractADType - derivative(f, ref_backend, x) + der1_true = if ref_backend isa AbstractADType + derivative(f, maybe_inner(ref_backend), x) else new_scen.first_order_ref(x) end @@ -839,10 +837,8 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_second_derivative(f, ba, mycopy_random(x)) - der1_true = if ref_backend isa SecondOrder - derivative(f, inner(ref_backend), x) - elseif ref_backend isa AbstractADType - derivative(f, ref_backend, x) + der1_true = if ref_backend isa AbstractADType + derivative(f, maybe_inner(ref_backend), x) else new_scen.first_order_ref(x) end @@ -972,6 +968,11 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_hessian(f, ba, mycopy_random(x)) + grad_true = if ref_backend isa AbstractADType + gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) + else + new_scen.first_order_ref(x) + end hess_true = if ref_backend isa AbstractADType hessian(f, ref_backend, x) else @@ -979,13 +980,21 @@ function test_correctness( end hess1 = hessian(f, ba, x, extras) + y2, grad2, hess2 = value_gradient_and_hessian(f, ba, x, extras) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin @test extras isa HessianExtras end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "Gradient value" begin + @test grad2 ≈ grad_true + end @testset "Hessian value" begin @test hess1 ≈ hess_true + @test hess2 ≈ hess_true end end test_scen_intact(new_scen, scen) @@ -1002,6 +1011,11 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) extras = prepare_hessian(f, ba, mycopy_random(x)) + grad_true = if ref_backend isa AbstractADType + gradient(f, maybe_dense_ad(maybe_inner(ref_backend)), x) + else + new_scen.first_order_ref(x) + end hess_true = if ref_backend isa AbstractADType hessian(f, ref_backend, x) else @@ -1010,14 +1024,25 @@ function test_correctness( hess1_in = mysimilar(hess_true) hess1 = hessian!(f, hess1_in, ba, x, extras) + grad2_in, hess2_in = mysimilar(grad_true), mysimilar(hess_true) + y2, grad2, hess2 = value_gradient_and_hessian!(f, grad2_in, hess2_in, ba, x, extras) let (≈)(x, y) = isapprox(x, y; atol, rtol) @testset "Extras type" begin @test extras isa HessianExtras end + @testset "Primal value" begin + @test y2 ≈ y + end + @testset "Gradient value" begin + @test grad2_in ≈ grad_true + @test grad2 ≈ grad_true + end @testset "Hessian value" begin @test hess1_in ≈ hess_true + @test hess2_in ≈ hess_true @test hess1 ≈ hess_true + @test hess2 ≈ hess_true end end test_scen_intact(new_scen, scen) diff --git a/DifferentiationInterfaceTest/src/tests/sparsity.jl b/DifferentiationInterfaceTest/src/tests/sparsity.jl index 492fbc734..f129015d9 100644 --- a/DifferentiationInterfaceTest/src/tests/sparsity.jl +++ b/DifferentiationInterfaceTest/src/tests/sparsity.jl @@ -99,9 +99,11 @@ function test_sparsity( end hess1 = hessian(f, ba, x, extras) + _, _, hess2 = value_gradient_and_hessian(f, ba, x, extras) @testset "Sparsity pattern" begin @test mynnz(hess1) == mynnz(hess_true) + @test mynnz(hess2) == mynnz(hess_true) end return nothing end @@ -116,9 +118,13 @@ function test_sparsity(ba::AbstractADType, scen::HessianScenario{1,:inplace}; re end hess1 = hessian!(f, mysimilar(hess_true), ba, x, extras) + _, _, hess2 = value_gradient_and_hessian!( + f, mysimilar(x), mysimilar(hess_true), ba, x, extras + ) @testset "Sparsity pattern" begin @test mynnz(hess1) == mynnz(hess_true) + @test mynnz(hess2) == mynnz(hess_true) end return nothing end diff --git a/DifferentiationInterfaceTest/src/tests/type_stability.jl b/DifferentiationInterfaceTest/src/tests/type_stability.jl index 60e6358a1..e87ee6acd 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability.jl @@ -283,14 +283,19 @@ function test_jet(ba::AbstractADType, scen::HessianScenario{1,:outofplace}; ref_ extras = prepare_hessian(f, ba, x) JET.@test_opt function_filter = filt hessian(f, ba, x, extras) + JET.@test_opt function_filter = filt value_gradient_and_hessian(f, ba, x, extras) return nothing end function test_jet(ba::AbstractADType, scen::HessianScenario{1,:inplace}; ref_backend) @compat (; f, x, y) = deepcopy(scen) extras = prepare_hessian(f, ba, x) + grad_in = mysimilar(x) hess_in = Matrix{typeof(y)}(undef, length(x), length(x)) JET.@test_opt function_filter = filt hessian!(f, hess_in, ba, x, extras) + JET.@test_opt function_filter = filt value_gradient_and_hessian!( + f, grad_in, hess_in, ba, x, extras + ) return nothing end