diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index a2a4b6d..7574af7 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -30,7 +30,7 @@ function OptimizationBase.instantiate_function( adtype, soadtype = OptimizationBase.generate_adtype(adtype) if g == true && f.grad === nothing - prep_grad = prepare_gradient(f.f, adtype, x, Constant(p)) + prep_grad = prepare_gradient(f.f, adtype, x, Constant(p), strict=Val(false)) function grad(res, θ) gradient!(f.f, res, prep_grad, adtype, θ, Constant(p)) end @@ -47,7 +47,7 @@ function OptimizationBase.instantiate_function( if fg == true && f.fg === nothing if g == false - prep_grad = prepare_gradient(f.f, adtype, x, Constant(p)) + prep_grad = prepare_gradient(f.f, adtype, x, Constant(p), strict=Val(false)) end function fg!(res, θ) (y, _) = value_and_gradient!(f.f, res, prep_grad, adtype, θ, Constant(p)) @@ -68,7 +68,7 @@ function OptimizationBase.instantiate_function( hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if h == true && f.hess === nothing - prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p)) + prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p), strict=Val(false)) function hess(res, θ) hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p)) end @@ -143,7 +143,7 @@ function OptimizationBase.instantiate_function( cons_jac_prototype = f.cons_jac_prototype cons_jac_colorvec = f.cons_jac_colorvec if cons !== nothing && cons_j == true && f.cons_j === nothing - prep_jac = prepare_jacobian(cons_oop, adtype, x) + prep_jac = prepare_jacobian(cons_oop, adtype, x, strict=Val(false)) function cons_j!(J, θ) jacobian!(cons_oop, J, prep_jac, adtype, θ) if size(J, 1) == 1 @@ -157,7 +157,7 @@ function OptimizationBase.instantiate_function( end if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing - prep_pullback = prepare_pullback(cons_oop, adtype, x, (ones(eltype(x), num_cons),)) + prep_pullback = prepare_pullback(cons_oop, adtype, x, (ones(eltype(x), num_cons),), strict=Val(false)) function cons_vjp!(J, θ, v) pullback!(cons_oop, (J,), prep_pullback, adtype, θ, (v,)) end @@ -169,7 +169,7 @@ function OptimizationBase.instantiate_function( if cons !== nothing && f.cons_jvp === nothing && cons_jvp == true prep_pushforward = prepare_pushforward( - cons_oop, adtype, x, (ones(eltype(x), length(x)),)) + cons_oop, adtype, x, (ones(eltype(x), length(x)),), strict=Val(false)) function cons_jvp!(J, θ, v) pushforward!(cons_oop, (J,), prep_pushforward, adtype, θ, (v,)) end @@ -182,7 +182,7 @@ function OptimizationBase.instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && cons_h == true && f.cons_h === nothing - prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i)) + prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i), strict=Val(false)) for i in 1:num_cons] function cons_h!(H, θ) @@ -201,7 +201,7 @@ function OptimizationBase.instantiate_function( if f.lag_h === nothing && cons !== nothing && lag_h == true lag_extras = prepare_hessian( lagrangian, soadtype, x, Constant(one(eltype(x))), - Constant(ones(eltype(x), num_cons)), Constant(p)) + Constant(ones(eltype(x), num_cons)), Constant(p), strict=Val(false)) lag_hess_prototype = zeros(Bool, num_cons, length(x)) function lag_h!(H::AbstractMatrix, θ, σ, λ) @@ -294,7 +294,7 @@ function OptimizationBase.instantiate_function( adtype, soadtype = OptimizationBase.generate_sparse_adtype(adtype) if g == true && f.grad === nothing - extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p)) + extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p), strict=Val(false)) function grad(res, θ) gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p)) end @@ -311,7 +311,7 @@ function OptimizationBase.instantiate_function( if fg == true && f.fg === nothing if g == false - extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p)) + extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p), strict=Val(false)) end function fg!(res, θ) (y, _) = value_and_gradient!( @@ -334,7 +334,7 @@ function OptimizationBase.instantiate_function( hess_sparsity = f.hess_prototype hess_colors = f.hess_colorvec if h == true && f.hess === nothing - prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p)) + prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p), strict=Val(false)) function hess(res, θ) hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p)) end @@ -458,7 +458,7 @@ function OptimizationBase.instantiate_function( conshess_sparsity = f.cons_hess_prototype conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing && cons_h == true - prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i)) + prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i), strict=Val(false)) for i in 1:num_cons] colores = getfield.(prep_cons_hess, :coloring_result) conshess_sparsity = getfield.(colores, :A) @@ -479,7 +479,7 @@ function OptimizationBase.instantiate_function( if cons !== nothing && f.lag_h === nothing && lag_h == true lag_extras = prepare_hessian( lagrangian, soadtype, x, Constant(one(eltype(x))), - Constant(ones(eltype(x), num_cons)), Constant(p)) + Constant(ones(eltype(x), num_cons)), Constant(p), strict=Val(false)) lag_hess_prototype = lag_extras.coloring_result.A lag_hess_colors = lag_extras.coloring_result.color