|  | 
|  | 1 | +abstract type AbstractDynamicOptProblem{uType, tType, isinplace} <: | 
|  | 2 | +              SciMLBase.AbstractODEProblem{uType, tType, isinplace} end | 
|  | 3 | + | 
|  | 4 | +struct DynamicOptSolution | 
|  | 5 | +    model::Any | 
|  | 6 | +    sol::ODESolution | 
|  | 7 | +    input_sol::Union{Nothing, ODESolution} | 
|  | 8 | +end | 
|  | 9 | + | 
|  | 10 | +function Base.show(io::IO, sol::DynamicOptSolution) | 
|  | 11 | +    println("retcode: ", sol.sol.retcode, "\n") | 
|  | 12 | + | 
|  | 13 | +    println("Optimal control solution for following model:\n") | 
|  | 14 | +    show(sol.model) | 
|  | 15 | + | 
|  | 16 | +    print("\n\nPlease query the model using sol.model, the solution trajectory for the system using sol.sol, or the solution trajectory for the controllers using sol.input_sol.") | 
|  | 17 | +end | 
|  | 18 | + | 
|  | 19 | +function JuMPDynamicOptProblem end | 
|  | 20 | +function InfiniteOptDynamicOptProblem end | 
|  | 21 | + | 
|  | 22 | +function warn_overdetermined(sys, u0map) | 
|  | 23 | +    constraintsys = get_constraintsystem(sys) | 
|  | 24 | +    if !isnothing(constraintsys) | 
|  | 25 | +        (length(constraints(constraintsys)) + length(u0map) > length(unknowns(sys))) && | 
|  | 26 | +            @warn "The control problem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The solvers will default to doing a nonlinear least-squares optimization." | 
|  | 27 | +    end | 
|  | 28 | +end | 
|  | 29 | + | 
|  | 30 | +""" | 
|  | 31 | +Generate the control function f(x, u, p, t) from the ODESystem.  | 
|  | 32 | +Input variables are automatically inferred but can be manually specified. | 
|  | 33 | +""" | 
|  | 34 | +function SciMLBase.ODEInputFunction{iip, specialize}(sys::ODESystem, | 
|  | 35 | +        dvs = unknowns(sys), | 
|  | 36 | +        ps = parameters(sys), u0 = nothing, | 
|  | 37 | +        inputs = unbound_inputs(sys), | 
|  | 38 | +        disturbance_inputs = disturbances(sys); | 
|  | 39 | +        version = nothing, tgrad = false, | 
|  | 40 | +        jac = false, controljac = false, | 
|  | 41 | +        p = nothing, t = nothing, | 
|  | 42 | +        eval_expression = false, | 
|  | 43 | +        sparse = false, simplify = false, | 
|  | 44 | +        eval_module = @__MODULE__, | 
|  | 45 | +        steady_state = false, | 
|  | 46 | +        checkbounds = false, | 
|  | 47 | +        sparsity = false, | 
|  | 48 | +        analytic = nothing, | 
|  | 49 | +        split_idxs = nothing, | 
|  | 50 | +        initialization_data = nothing, | 
|  | 51 | +        cse = true, | 
|  | 52 | +        kwargs...) where {iip, specialize} | 
|  | 53 | +    (f), _, _ = generate_control_function( | 
|  | 54 | +        sys, inputs, disturbance_inputs; eval_module, cse, kwargs...) | 
|  | 55 | + | 
|  | 56 | +    if tgrad | 
|  | 57 | +        tgrad_gen = generate_tgrad(sys, dvs, ps; | 
|  | 58 | +            simplify = simplify, | 
|  | 59 | +            expression = Val{true}, | 
|  | 60 | +            expression_module = eval_module, cse, | 
|  | 61 | +            checkbounds = checkbounds, kwargs...) | 
|  | 62 | +        tgrad_oop, tgrad_iip = eval_or_rgf.(tgrad_gen; eval_expression, eval_module) | 
|  | 63 | +        _tgrad = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(tgrad_oop, tgrad_iip) | 
|  | 64 | +    else | 
|  | 65 | +        _tgrad = nothing | 
|  | 66 | +    end | 
|  | 67 | + | 
|  | 68 | +    if jac | 
|  | 69 | +        jac_gen = generate_jacobian(sys, dvs, ps; | 
|  | 70 | +            simplify = simplify, sparse = sparse, | 
|  | 71 | +            expression = Val{true}, | 
|  | 72 | +            expression_module = eval_module, cse, | 
|  | 73 | +            checkbounds = checkbounds, kwargs...) | 
|  | 74 | +        jac_oop, jac_iip = eval_or_rgf.(jac_gen; eval_expression, eval_module) | 
|  | 75 | + | 
|  | 76 | +        _jac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(jac_oop, jac_iip) | 
|  | 77 | +    else | 
|  | 78 | +        _jac = nothing | 
|  | 79 | +    end | 
|  | 80 | + | 
|  | 81 | +    if controljac | 
|  | 82 | +        cjac_gen = generate_control_jacobian(sys, dvs, ps; | 
|  | 83 | +            simplify = simplify, sparse = sparse, | 
|  | 84 | +            expression = Val{true}, | 
|  | 85 | +            expression_module = eval_module, cse, | 
|  | 86 | +            checkbounds = checkbounds, kwargs...) | 
|  | 87 | +        cjac_oop, cjac_iip = eval_or_rgf.(cjac_gen; eval_expression, eval_module) | 
|  | 88 | + | 
|  | 89 | +        _cjac = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(cjac_oop, cjac_iip) | 
|  | 90 | +    else | 
|  | 91 | +        _cjac = nothing | 
|  | 92 | +    end | 
|  | 93 | + | 
|  | 94 | +    M = calculate_massmatrix(sys) | 
|  | 95 | +    _M = if sparse && !(u0 === nothing || M === I) | 
|  | 96 | +        SparseArrays.sparse(M) | 
|  | 97 | +    elseif u0 === nothing || M === I | 
|  | 98 | +        M | 
|  | 99 | +    else | 
|  | 100 | +        ArrayInterface.restructure(u0 .* u0', M) | 
|  | 101 | +    end | 
|  | 102 | + | 
|  | 103 | +    observedfun = ObservedFunctionCache( | 
|  | 104 | +        sys; steady_state, eval_expression, eval_module, checkbounds, cse) | 
|  | 105 | + | 
|  | 106 | +    if sparse | 
|  | 107 | +        uElType = u0 === nothing ? Float64 : eltype(u0) | 
|  | 108 | +        W_prototype = similar(W_sparsity(sys), uElType) | 
|  | 109 | +        controljac_prototype = similar(calculate_control_jacobian(sys), uElType) | 
|  | 110 | +    else | 
|  | 111 | +        W_prototype = nothing | 
|  | 112 | +        controljac_prototype = nothing | 
|  | 113 | +    end | 
|  | 114 | + | 
|  | 115 | +    ODEInputFunction{iip, specialize}(f; | 
|  | 116 | +        sys = sys, | 
|  | 117 | +        jac = _jac === nothing ? nothing : _jac, | 
|  | 118 | +        controljac = _cjac === nothing ? nothing : _cjac, | 
|  | 119 | +        tgrad = _tgrad === nothing ? nothing : _tgrad, | 
|  | 120 | +        mass_matrix = _M, | 
|  | 121 | +        jac_prototype = W_prototype, | 
|  | 122 | +        controljac_prototype = controljac_prototype, | 
|  | 123 | +        observed = observedfun, | 
|  | 124 | +        sparsity = sparsity ? W_sparsity(sys) : nothing, | 
|  | 125 | +        analytic = analytic, | 
|  | 126 | +        initialization_data) | 
|  | 127 | +end | 
|  | 128 | + | 
|  | 129 | +function SciMLBase.ODEInputFunction(sys::AbstractODESystem, args...; kwargs...) | 
|  | 130 | +    ODEInputFunction{true}(sys, args...; kwargs...) | 
|  | 131 | +end | 
|  | 132 | + | 
|  | 133 | +function SciMLBase.ODEInputFunction{true}(sys::AbstractODESystem, args...; | 
|  | 134 | +        kwargs...) | 
|  | 135 | +    ODEInputFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) | 
|  | 136 | +end | 
|  | 137 | + | 
|  | 138 | +function SciMLBase.ODEInputFunction{false}(sys::AbstractODESystem, args...; | 
|  | 139 | +        kwargs...) | 
|  | 140 | +    ODEInputFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) | 
|  | 141 | +end | 
|  | 142 | + | 
|  | 143 | +# returns the JuMP timespan, the number of steps, and whether it is a free time problem. | 
|  | 144 | +function process_tspan(tspan, dt, steps) | 
|  | 145 | +    is_free_time = false | 
|  | 146 | +    if isnothing(dt) && isnothing(steps) | 
|  | 147 | +        error("Must provide either the dt or the number of intervals to the collocation solvers (JuMP, InfiniteOpt, CasADi).") | 
|  | 148 | +    elseif symbolic_type(tspan[1]) === ScalarSymbolic() || | 
|  | 149 | +           symbolic_type(tspan[2]) === ScalarSymbolic() | 
|  | 150 | +        isnothing(steps) && | 
|  | 151 | +            error("Free final time problems require specifying the number of steps using the keyword arg `steps`, rather than dt.") | 
|  | 152 | +        isnothing(dt) || | 
|  | 153 | +            @warn "Specified dt for free final time problem. This will be ignored; dt will be determined by the number of timesteps." | 
|  | 154 | + | 
|  | 155 | +        return steps, true | 
|  | 156 | +    else | 
|  | 157 | +        isnothing(steps) || | 
|  | 158 | +            @warn "Specified number of steps for problem with concrete tspan. This will be ignored; number of steps will be determined by dt." | 
|  | 159 | + | 
|  | 160 | +        return length(tspan[1]:dt:tspan[2]), false | 
|  | 161 | +    end | 
|  | 162 | +end | 
0 commit comments