47
47
48
48
is_explicit (tableau) = tableau isa DiffEqBase. ExplicitRKTableau
49
49
50
- """
51
- Generate the control function f(x, u, p, t) from the ODESystem.
52
- Input variables are automatically inferred but can be manually specified.
53
- """
54
- function SciMLBase. ODEInputFunction {iip, specialize} (sys:: System ,
55
- dvs = unknowns (sys),
56
- ps = parameters (sys), u0 = nothing ,
50
+ @fallback_iip_specialize function SciMLBase. ODEInputFunction {iip, specialize} (sys:: System ;
57
51
inputs = unbound_inputs (sys),
58
- disturbance_inputs = disturbances (sys);
59
- version = nothing , tgrad = false ,
52
+ disturbance_inputs = disturbances (sys),
53
+ u0 = nothing , tgrad = false ,
60
54
jac = false , controljac = false ,
61
55
p = nothing , t = nothing ,
62
56
eval_expression = false ,
@@ -66,7 +60,6 @@ function SciMLBase.ODEInputFunction{iip, specialize}(sys::System,
66
60
checkbounds = false ,
67
61
sparsity = false ,
68
62
analytic = nothing ,
69
- split_idxs = nothing ,
70
63
initialization_data = nothing ,
71
64
cse = true ,
72
65
kwargs... ) where {iip, specialize}
@@ -75,61 +68,49 @@ function SciMLBase.ODEInputFunction{iip, specialize}(sys::System,
75
68
f = f[1 ]
76
69
77
70
if tgrad
78
- tgrad_gen = generate_tgrad (sys, dvs, ps ;
71
+ _tgrad = generate_tgrad (sys;
79
72
simplify = simplify,
80
73
expression = Val{true },
74
+ wrap_gfw = Val{true },
81
75
expression_module = eval_module, cse,
82
76
checkbounds = checkbounds, kwargs... )
83
- tgrad_oop, tgrad_iip = eval_or_rgf .(tgrad_gen; eval_expression, eval_module)
84
- _tgrad = GeneratedFunctionWrapper {(2, 3, is_split(sys))} (tgrad_oop, tgrad_iip)
85
77
else
86
78
_tgrad = nothing
87
79
end
88
80
89
81
if jac
90
- jac_gen = generate_jacobian (sys, dvs, ps ;
82
+ _jac = generate_jacobian (sys;
91
83
simplify = simplify, sparse = sparse,
92
84
expression = Val{true },
85
+ wrap_gfw = Val{true },
93
86
expression_module = eval_module, cse,
94
87
checkbounds = checkbounds, kwargs... )
95
- jac_oop, jac_iip = eval_or_rgf .(jac_gen; eval_expression, eval_module)
96
-
97
- _jac = GeneratedFunctionWrapper {(2, 3, is_split(sys))} (jac_oop, jac_iip)
98
88
else
99
89
_jac = nothing
100
90
end
101
91
102
92
if controljac
103
- cjac_gen = generate_control_jacobian (sys, dvs, ps ;
93
+ _cjac = generate_control_jacobian (sys;
104
94
simplify = simplify, sparse = sparse,
105
- expression = Val{true },
95
+ expression = Val{true }, wrap_gfw = Val{ true },
106
96
expression_module = eval_module, cse,
107
97
checkbounds = checkbounds, kwargs... )
108
- cjac_oop, cjac_iip = eval_or_rgf .(cjac_gen; eval_expression, eval_module)
109
-
110
- _cjac = GeneratedFunctionWrapper {(2, 3, is_split(sys))} (cjac_oop, cjac_iip)
111
98
else
112
99
_cjac = nothing
113
100
end
114
101
115
102
M = calculate_massmatrix (sys)
116
- _M = if sparse && ! (u0 === nothing || M === I)
117
- SparseArrays. sparse (M)
118
- elseif u0 === nothing || M === I
119
- M
120
- else
121
- ArrayInterface. restructure (u0 .* u0' , M)
122
- end
103
+ _M = concrete_massmatrix (M; sparse, u0)
123
104
124
105
observedfun = ObservedFunctionCache (
125
106
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
126
107
108
+ _W_sparsity = W_sparsity (sys)
109
+ W_prototype = calculate_W_prototype (_W_sparsity; u0, sparse)
127
110
if sparse
128
111
uElType = u0 === nothing ? Float64 : eltype (u0)
129
- W_prototype = similar (W_sparsity (sys), uElType)
130
112
controljac_prototype = similar (calculate_control_jacobian (sys), uElType)
131
113
else
132
- W_prototype = nothing
133
114
controljac_prototype = nothing
134
115
end
135
116
@@ -142,25 +123,11 @@ function SciMLBase.ODEInputFunction{iip, specialize}(sys::System,
142
123
jac_prototype = W_prototype,
143
124
controljac_prototype = controljac_prototype,
144
125
observed = observedfun,
145
- sparsity = sparsity ? W_sparsity (sys) : nothing ,
126
+ sparsity = sparsity ? _W_sparsity : nothing ,
146
127
analytic = analytic,
147
128
initialization_data)
148
129
end
149
130
150
- function SciMLBase. ODEInputFunction (sys:: System , args... ; kwargs... )
151
- ODEInputFunction {true} (sys, args... ; kwargs... )
152
- end
153
-
154
- function SciMLBase. ODEInputFunction {true} (sys:: System , args... ;
155
- kwargs... )
156
- ODEInputFunction {true, SciMLBase.AutoSpecialize} (sys, args... ; kwargs... )
157
- end
158
-
159
- function SciMLBase. ODEInputFunction {false} (sys:: System , args... ;
160
- kwargs... )
161
- ODEInputFunction {false, SciMLBase.FullSpecialize} (sys, args... ; kwargs... )
162
- end
163
-
164
131
# returns the JuMP timespan, the number of steps, and whether it is a free time problem.
165
132
function process_tspan (tspan, dt, steps)
166
133
is_free_time = false
0 commit comments