@@ -582,43 +582,14 @@ function get_optim_functions(
582582 jac_backend :: AbstractADType
583583) where JNT<: Real
584584 model, transcription = mpc. estim. model, mpc. transcription
585- nu, ny, nx̂, nϵ, Hp, Hc = model. nu, model. ny, mpc. estim. nx̂, mpc. nϵ, mpc. Hp, mpc. Hc
586- ng, nc, neq = length (mpc. con. i_g), mpc. con. nc, mpc. con. neq
587- nZ̃, nU, nŶ, nX̂ = length (mpc. Z̃), Hp* nu, Hp* ny, Hp* nx̂
588- nΔŨ, nUe, nŶe = nu* Hc + nϵ, nU + nu, nŶ + ny
589- Ncache = nZ̃ + 3
590- myNaN = convert (JNT, NaN ) # fill Z̃ with NaNs to force update_simulations! at 1st call:
591- # ---------------------- differentiation cache ---------------------------------------
592- Z̃_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (fill (myNaN, nZ̃), Ncache)
593- ΔŨ_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nΔŨ), Ncache)
594- x̂0end_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nx̂), Ncache)
595- Ŷe_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶe), Ncache)
596- Ue_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nUe), Ncache)
597- Ŷ0_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nŶ), Ncache)
598- U0_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Ncache)
599- Û0_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nU), Ncache)
600- X̂0_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nX̂), Ncache)
601- gc_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, nc), Ncache)
602- g_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, ng), Ncache)
603- geq_cache:: DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache (zeros (JNT, neq), Ncache)
604585 # --------------------- update simulation function ------------------------------------
605- function update_simulations! (
606- Z̃arg:: Union{NTuple{N, T}, AbstractVector{T}} , Z̃cache
607- ) where {N, T<: Real }
608- if isdifferent (Z̃cache, Z̃arg)
609- for i in eachindex (Z̃cache)
586+ function update_simulations! (Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
587+ if isdifferent (Z̃arg, Z̃)
588+ for i in eachindex (Z̃)
610589 # Z̃cache .= Z̃arg is type unstable with Z̃arg::NTuple{N, ForwardDiff.Dual}
611- Z̃cache [i] = Z̃arg[i]
590+ Z̃ [i] = Z̃arg[i]
612591 end
613- Z̃ = Z̃cache
614592 ϵ = (nϵ ≠ 0 ) ? Z̃[end ] : zero (T) # ϵ = 0 if nϵ == 0 (meaning no relaxation)
615- ΔŨ = get_tmp (ΔŨ_cache, T)
616- x̂0end = get_tmp (x̂0end_cache, T)
617- Ue, Ŷe = get_tmp (Ue_cache, T), get_tmp (Ŷe_cache, T)
618- U0, Ŷ0 = get_tmp (U0_cache, T), get_tmp (Ŷ0_cache, T)
619- X̂0, Û0 = get_tmp (X̂0_cache, T), get_tmp (Û0_cache, T)
620- gc, g = get_tmp (gc_cache, T), get_tmp (g_cache, T)
621- geq = get_tmp (geq_cache, T)
622593 U0 = getU0! (U0, mpc, Z̃)
623594 ΔŨ = getΔŨ! (ΔŨ, mpc, transcription, Z̃)
624595 Ŷ0, x̂0end = predict! (Ŷ0, x̂0end, X̂0, Û0, mpc, model, transcription, U0, Z̃)
@@ -629,78 +600,151 @@ function get_optim_functions(
629600 end
630601 return nothing
631602 end
632- # --------------------- objective functions -------------------------------------------
603+ # ---------------------- JNT vectors cache --------------------------------------------
604+ nu, ny, nx̂, nϵ, Hp, Hc = model. nu, model. ny, mpc. estim. nx̂, mpc. nϵ, mpc. Hp, mpc. Hc
605+ ng, nc, neq = length (mpc. con. i_g), mpc. con. nc, mpc. con. neq
606+ nZ̃, nU, nŶ, nX̂ = length (mpc. Z̃), Hp* nu, Hp* ny, Hp* nx̂
607+ nΔŨ, nUe, nŶe = nu* Hc + nϵ, nU + nu, nŶ + ny
608+ myNaN = convert (JNT, NaN )
609+ Z̃ = fill (myNaN, nZ̃)
610+ ΔŨ = zeros (JNT, nΔŨ)
611+ x̂0end = zeros (JNT, nx̂)
612+ Ue, Ŷe = zeros (JNT, nUe), zeros (JNT, nŶe)
613+ U0, Ŷ0 = zeros (JNT, nU), zeros (JNT, nŶ)
614+ Û0, X̂0 = zeros (JNT, nU), zeros (JNT, nX̂)
615+ gc, g = zeros (JNT, nc), zeros (JNT, ng)
616+ geq = zeros (JNT, neq)
617+
618+ # WIP: still receives INVALID_MODEL once in a while; needs investigation.
619+
620+
633621 function Jfunc (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
634- update_simulations! (Z̃arg, get_tmp (Z̃_cache, T))
635- ΔŨ = get_tmp (ΔŨ_cache, T)
636- Ue, Ŷe = get_tmp (Ue_cache, T), get_tmp (Ŷe_cache, T)
637- U0, Ŷ0 = get_tmp (U0_cache, T), get_tmp (Ŷ0_cache, T)
622+ update_simulations! (Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
638623 return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ):: T
639624 end
640- function Jfunc_vec (Z̃arg:: AbstractVector{T} ) where T<: Real
641- update_simulations! (Z̃arg, get_tmp (Z̃_cache, T))
642- ΔŨ = get_tmp (ΔŨ_cache, T)
643- Ue, Ŷe = get_tmp (Ue_cache, T), get_tmp (Ŷe_cache, T)
644- U0, Ŷ0 = get_tmp (U0_cache, T), get_tmp (Ŷ0_cache, T)
645- return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ):: T
625+ function Jfunc_vec! (Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
626+ update_simulations! (Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
627+ return obj_nonlinprog! (Ŷ0, U0, mpc, model, Ue, Ŷe, ΔŨ)
646628 end
629+
630+
631+
632+ gfuncs = Vector {Function} (undef, ng)
633+ for i in eachindex (gfuncs)
634+ func_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
635+ update_simulations! (Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
636+ return g[i]:: T
637+ end
638+ gfuncs[i] = func_i
639+ end
640+ function gfunc_vec! (g, Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, geq)
641+ update_simulations! (Z̃arg, Z̃, ΔŨ, x̂0end, Ue, Ŷe, U0, Ŷ0, Û0, X̂0, gc, g, geq)
642+ return g
643+ end
644+
645+
646+
647+
648+
649+
650+
651+
652+ Z̃_∇J = fill (myNaN, nZ̃)
653+ ΔŨ_∇J = zeros (JNT, nΔŨ)
654+ x̂0end_∇J = zeros (JNT, nx̂)
655+ Ue_∇J, Ŷe_∇J = zeros (JNT, nUe), zeros (JNT, nŶe)
656+ U0_∇J, Ŷ0_∇J = zeros (JNT, nU), zeros (JNT, nŶ)
657+ Û0_∇J, X̂0_∇J = zeros (JNT, nU), zeros (JNT, nX̂)
658+ gc_∇J, g_∇J = zeros (JNT, nc), zeros (JNT, ng)
659+ geq_∇J = zeros (JNT, neq)
660+
661+
662+
663+
664+
665+
666+
647667 Z̃_∇J = fill (myNaN, nZ̃)
648668 ∇J = Vector {JNT} (undef, nZ̃) # gradient of objective J
649- ∇J_prep = prepare_gradient (Jfunc_vec, grad_backend, Z̃_∇J)
669+ ∇J_context = (
670+ Cache (Z̃_∇J), Cache (ΔŨ_∇J), Cache (x̂0end_∇J),
671+ Cache (Ue_∇J), Cache (Ŷe_∇J),
672+ Cache (U0_∇J), Cache (Ŷ0_∇J),
673+ Cache (Û0_∇J), Cache (X̂0_∇J),
674+ Cache (gc_∇J), Cache (g_∇J), Cache (geq_∇J)
675+ )
676+ ∇J_prep = prepare_gradient (Jfunc_vec!, grad_backend, Z̃_∇J, ∇J_context... )
650677 ∇Jfunc! = if nZ̃ == 1
651- function (Z̃arg:: T ) where T <: Real
678+ function (Z̃arg)
652679 Z̃_∇J .= Z̃arg
653- gradient! (Jfunc_vec, ∇J, ∇J_prep, grad_backend, Z̃_∇J)
680+ gradient! (Jfunc_vec! , ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context ... )
654681 return ∇J[begin ] # univariate syntax, see JuMP.@operator doc
655682 end
656683 else
657684 function (∇J:: AbstractVector{T} , Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
658685 Z̃_∇J .= Z̃arg
659- gradient! (Jfunc_vec, ∇J, ∇J_prep, grad_backend, Z̃_∇J)
686+ gradient! (Jfunc_vec! , ∇J, ∇J_prep, grad_backend, Z̃_∇J, ∇J_context ... )
660687 return ∇J # multivariate syntax, see JuMP.@operator doc
661688 end
662689 end
663- # --------------------- inequality constraint functions -------------------------------
664- gfuncs = Vector {Function} (undef, ng)
665- for i in eachindex (gfuncs)
666- func_i = function (Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
667- update_simulations! (Z̃arg, get_tmp (Z̃_cache, T))
668- g = get_tmp (g_cache, T)
669- return g[i]:: T
670- end
671- gfuncs[i] = func_i
672- end
673- function gfunc_vec! (g, Z̃vec:: AbstractVector{T} ) where T<: Real
674- update_simulations! (Z̃vec, get_tmp (Z̃_cache, T))
675- g .= get_tmp (g_cache, T)
676- return g
677- end
690+
691+
692+
693+
694+
695+
696+
697+ Z̃_∇g = fill (myNaN, nZ̃)
698+ ΔŨ_∇g = zeros (JNT, nΔŨ)
699+ x̂0end_∇g = zeros (JNT, nx̂)
700+ Ue_∇g, Ŷe_∇g = zeros (JNT, nUe), zeros (JNT, nŶe)
701+ U0_∇g, Ŷ0_∇g = zeros (JNT, nU), zeros (JNT, nŶ)
702+ Û0_∇g, X̂0_∇g = zeros (JNT, nU), zeros (JNT, nX̂)
703+ gc_∇g, g_∇g = zeros (JNT, nc), zeros (JNT, ng)
704+ geq_∇g = zeros (JNT, neq)
705+
706+
707+
708+
709+
678710 Z̃_∇g = fill (myNaN, nZ̃)
679711 g_vec = Vector {JNT} (undef, ng)
680712 ∇g = Matrix {JNT} (undef, ng, nZ̃) # Jacobian of inequality constraints g
681- ∇g_prep = prepare_jacobian (gfunc_vec!, g_vec, jac_backend, Z̃_∇g)
713+ ∇g_context = (
714+ Cache (Z̃_∇g), Cache (ΔŨ_∇g), Cache (x̂0end_∇g),
715+ Cache (Ue_∇g), Cache (Ŷe_∇g),
716+ Cache (U0_∇g), Cache (Ŷ0_∇g),
717+ Cache (Û0_∇g), Cache (X̂0_∇g),
718+ Cache (gc_∇g), Cache (geq_∇g)
719+ )
720+ ∇g_prep = prepare_jacobian (gfunc_vec!, g_vec, jac_backend, Z̃_∇g, ∇g_context... )
682721 ∇gfuncs! = Vector {Function} (undef, ng)
683722 for i in eachindex (∇gfuncs!)
684723 ∇gfuncs![i] = if nZ̃ == 1
685724 function (Z̃arg:: T ) where T<: Real
686725 if isdifferent (Z̃arg, Z̃_∇g)
687726 Z̃_∇g .= Z̃arg
688- jacobian! (gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g)
727+ jacobian! (
728+ gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g, ∇g_context...
729+ )
689730 end
690731 return ∇g[i, begin ] # univariate syntax, see JuMP.@operator doc
691732 end
692733 else
693734 function (∇g_i, Z̃arg:: Vararg{T, N} ) where {N, T<: Real }
694735 if isdifferent (Z̃arg, Z̃_∇g)
695736 Z̃_∇g .= Z̃arg
696- jacobian! (gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g)
737+ jacobian! (
738+ gfunc_vec!, g_vec, ∇g, ∇g_prep, jac_backend, Z̃_∇g, ∇g_context...
739+ )
697740 end
698741 return ∇g_i .= @views ∇g[i, :] # multivariate syntax, see JuMP.@operator doc
699742 end
700743 end
701744 end
702745 # --------------------- equality constraint functions ---------------------------------
703746 geqfuncs = Vector {Function} (undef, neq)
747+ #=
704748 for i in eachindex(geqfuncs)
705749 func_i = function (Z̃arg::Vararg{T, N}) where {N, T<:Real}
706750 update_simulations!(Z̃arg, get_tmp(Z̃_cache, T))
@@ -718,7 +762,9 @@ function get_optim_functions(
718762 geq_vec = Vector{JNT}(undef, neq)
719763 ∇geq = Matrix{JNT}(undef, neq, nZ̃) # Jacobian of equality constraints geq
720764 ∇geq_prep = prepare_jacobian(geqfunc_vec!, geq_vec, jac_backend, Z̃_∇geq)
765+ =#
721766 ∇geqfuncs! = Vector {Function} (undef, neq)
767+ #=
722768 for i in eachindex(∇geqfuncs!)
723769 # only multivariate syntax, univariate is impossible since nonlinear equality
724770 # constraints imply MultipleShooting, thus input increment ΔU and state X̂0 in Z̃:
@@ -731,6 +777,7 @@ function get_optim_functions(
731777 return ∇geq_i .= @views ∇geq[i, :]
732778 end
733779 end
780+ =#
734781 return Jfunc, ∇Jfunc!, gfuncs, ∇gfuncs!, geqfuncs, ∇geqfuncs!
735782end
736783
0 commit comments