@@ -589,11 +589,9 @@ function get_optim_functions(
589589 grad_backend:: AbstractADType ,
590590 jac_backend :: AbstractADType
591591) where JNT<: Real
592- model, transcription = mpc. estim. model, mpc. transcription
593- # TODO : fix type of all cache to ::Vector{JNT} (verify performance difference with and w/o)
594- # TODO : mêmes choses pour le MHE
595- # --------------------- update simulation function ------------------------------------
596- function update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
592+ # ------ update simulation function (all args after `mpc` are mutated) ----------------
593+ function update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
594+ model, transcription = mpc. estim. model, mpc. transcription
597595 U0 = getU0! (U0, mpc, Z̃)
598596 ΔŨ = getΔŨ! (ΔŨ, mpc, transcription, Z̃)
599597 Ŷ0, x̂0end = predict! (Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
@@ -605,36 +603,38 @@ function get_optim_functions(
605603 return nothing
606604 end
607605 # ----- common cache for Jfunc, gfuncs, geqfuncs called with floats -------------------
606+ model = mpc. estim. model
608607 nu, ny, nx̂, nϵ, Hp, Hc = model. nu, model. ny, mpc. estim. nx̂, mpc. nϵ, mpc. Hp, mpc. Hc
609608 ng, nc, neq = length (mpc. con. i_g), mpc. con. nc, mpc. con. neq
610609 nZ̃, nU, nŶ, nX̂ = length (mpc. Z̃), Hp* nu, Hp* ny, Hp* nx̂
611610 nΔŨ, nUe, nŶe = nu* Hc + nϵ, nU + nu, nŶ + ny
612- myNaN = convert (JNT, NaN )
613- Z̃ = fill (myNaN, nZ̃) # NaN to force update_simulations! at first call
614- ΔŨ = zeros (JNT, nΔŨ)
615- x̂0end = zeros (JNT, nx̂)
616- Ue, Ŷe = zeros (JNT, nUe), zeros (JNT, nŶe)
617- U0, Ŷ0 = zeros (JNT, nU), zeros (JNT, nŶ)
618- Û0, X̂0 = zeros (JNT, nU), zeros (JNT, nX̂)
619- gc, g = zeros (JNT, nc), zeros (JNT, ng)
620- geq = zeros (JNT, neq)
621- # ---------------------- objective function ------------------------------------------
611+ myNaN = convert (JNT, NaN ) # NaN to force update_simulations! at first call:
612+ Z̃ :: Vector{JNT} = fill (myNaN, nZ̃)
613+ ΔŨ:: Vector{JNT} = zeros (JNT, nΔŨ)
614+ x̂0end:: Vector{JNT} = zeros (JNT, nx̂)
615+ Ue:: Vector{JNT} , Ŷe:: Vector{JNT} = zeros (JNT, nUe), zeros (JNT, nŶe)
616+ U0:: Vector{JNT} , Ŷ0:: Vector{JNT} = zeros (JNT, nU), zeros (JNT, nŶ)
617+ Û0:: Vector{JNT} , X̂0:: Vector{JNT} = zeros (JNT, nU), zeros (JNT, nX̂)
618+ gc:: Vector{JNT} , g:: Vector{JNT} = zeros (JNT, nc), zeros (JNT, ng)
619+ geq:: Vector{JNT} = zeros (JNT, neq)
620+ # ---------------------- objective function -------------------------------------------
622621 function Jfunc (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
623622 if isdifferent (Z̃arg, Z̃)
624623 Z̃ .= Z̃arg
625- update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
624+ update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
626625 end
627626 return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ):: T
628627 end
629- function Jfunc! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
630- update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
628+ function Jfunc! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
629+ update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
631630 return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
632631 end
633632 Z̃_∇J = fill (myNaN, nZ̃)
634633 ∇J_context = (
634+ Constant (mpc),
635635 Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
636636 Cache (Û0), Cache (X̂0),
637- Cache (gc), Cache (g), Cache (geq)
637+ Cache (gc), Cache (g), Cache (geq),
638638 )
639639 ∇J_prep = prepare_gradient (Jfunc!, grad_backend, Z̃_∇J, ∇J_context... )
640640 ∇J = Vector {JNT} (undef, nZ̃)
@@ -657,26 +657,26 @@ function get_optim_functions(
657657 gfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
658658 if isdifferent (Z̃arg, Z̃)
659659 Z̃ .= Z̃arg
660- update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
660+ update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
661661 end
662662 return g[i]:: T
663663 end
664664 gfuncs[i] = gfunc_i
665665 end
666- function gfunc! (g, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
667- return update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
666+ function gfunc! (g, Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
667+ return update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
668668 end
669669 Z̃_∇g = fill (myNaN, nZ̃)
670670 ∇g_context = (
671+ Constant (mpc),
671672 Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
672673 Cache (Û0), Cache (X̂0),
673- Cache (gc), Cache (geq)
674+ Cache (gc), Cache (geq),
674675 )
675- # temporarily enable all the inequality constraints for sparsity pattern detection:
676- i_g_old = copy (mpc. con. i_g)
677- mpc. con. i_g .= true
676+ # temporarily enable all the inequality constraints for sparsity detection:
677+ mpc. con. i_g[1 : end - nc] .= true
678678 ∇g_prep = prepare_jacobian (gfunc!, g, jac_backend, Z̃_∇g, ∇g_context... )
679- mpc. con. i_g .= i_g_old
679+ mpc. con. i_g[ 1 : end - nc] .= false
680680 ∇g = init_diffmat (JNT, jac_backend, ∇g_prep, nZ̃, ng)
681681 ∇gfuncs! = Vector {Function} (undef, ng)
682682 for i in eachindex (∇gfuncs!)
@@ -705,17 +705,18 @@ function get_optim_functions(
705705 geqfunc_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
706706 if isdifferent (Z̃arg, Z̃)
707707 Z̃ .= Z̃arg
708- update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
708+ update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
709709 end
710710 return geq[i]:: T
711711 end
712712 geqfuncs[i] = geqfunc_i
713713 end
714- function geqfunc! (geq, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g)
715- return update_simulations! (Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
714+ function geqfunc! (geq, Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g)
715+ return update_simulations! (Z̃, mpc, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
716716 end
717717 Z̃_∇geq = fill (myNaN, nZ̃)
718718 ∇geq_context = (
719+ Constant (mpc),
719720 Cache (ΔŨ), Cache (x̂0end), Cache (Ue), Cache (Ŷe), Cache (U0), Cache (Ŷ0),
720721 Cache (Û0), Cache (X̂0),
721722 Cache (gc), Cache (g)
0 commit comments