@@ -5,22 +5,35 @@ Given an IR value `arg` that corresponds to `u` in SciML's ODE ABI, split it int
55the DAECompiler internal ABI.
66"""
77function sciml_ode_split_u! (compact, line, arg, numstates)
8- nassgn = numstates[AssignedDiff]
98 ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic] + numstates[AlgebraicDerivative]
109
11- u_mm = @insert_node_here compact line view (arg, 1 : nassgn):: VectorViewType
12- u_unassgn = @insert_node_here compact line view (arg, (nassgn+ 1 ): (nassgn+ numstates[UnassignedDiff])):: VectorViewType
13- alg = @insert_node_here compact line view (arg, (nassgn+ numstates[UnassignedDiff]+ 1 ): ntotalstates):: VectorViewType
10+ u_mm = @insert_node_here compact line view (arg,
11+ 1 : numstates[AssignedDiff]):: VectorViewType
12+ u_unassgn = @insert_node_here compact line view (arg,
13+ (numstates[AssignedDiff] + 1 ): (numstates[AssignedDiff] + numstates[UnassignedDiff])):: VectorViewType
14+ alg = @insert_node_here compact line view (arg,
15+ (numstates[AssignedDiff] + numstates[UnassignedDiff] + 1 ): (numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic])):: VectorViewType
16+ alg_derv = @insert_node_here compact line view (arg,
17+ (numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic] + 1 ): ntotalstates):: VectorViewType
18+
19+ return (u_mm, u_unassgn, alg, alg_derv)
20+ end
1421
15- return (u_mm, u_unassgn, alg)
22+ function generate_ode_mass_matrix (nd, na)
23+ n = nd + na
24+ mass_matrix = zeros (Float64, n, n)
25+ for i in 1 : nd
26+ mass_matrix[i, i] = 1.0
27+ end
28+ return mass_matrix
1629end
1730
18- function make_odefunction (f)
19- ODEFunction (f)
31+ function make_odefunction (f, mass_matrix = LinearAlgebra . I, initf = nothing )
32+ ODEFunction (f; mass_matrix, initialization_data = (initf === nothing ? nothing : initialization_data_ode (initf)) )
2033end
2134
22- function make_odefunction (f, initf)
23- ODEFunction (f; initialization_data = SciMLBase. OverrideInitData (NonlinearProblem ((args... )-> nothing , nothing , nothing ), nothing , initf, nothing , nothing ) )
35+ function initialization_data_ode ( initf)
36+ return SciMLBase. OverrideInitData (NonlinearProblem ((args... )-> nothing , nothing , nothing ), nothing , initf, nothing , nothing )
2437end
2538
2639"""
3952```
4053
4154"""
42- function ode_factory_gen (result :: DAEIPOResult , ci:: CodeInstance , key:: TornCacheKey , world:: UInt , init_key:: Union{TornCacheKey, Nothing} )
43- @ccall jl_safe_printf ( " $key \n " :: Cstring ) :: Cvoid
55+ function ode_factory_gen (state :: TransformationState , ci:: CodeInstance , key:: TornCacheKey , world:: UInt , init_key:: Union{TornCacheKey, Nothing} )
56+ result = state . result
4457 torn_ci = find_matching_ci (ci-> isa (ci. owner, TornIRSpec) && ci. owner. key == key, ci. def, world)
4558 torn_ir = torn_ci. inferred
4659
@@ -65,7 +78,7 @@ function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheK
6578 sicm = ()
6679 end
6780
68- odef_ci = rhs_finish! (result , ci, key, world, 1 )
81+ odef_ci = rhs_finish! (state , ci, key, world, 1 )
6982
7083 # Create a small opaque closure to adapt from SciML ABI to our own internal
7184 # ABI
@@ -76,7 +89,6 @@ function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheK
7689 for var = 1 : length (result. var_to_diff)
7790 kind = classify_var (result. var_to_diff, key, var)
7891 kind == nothing && continue
79- @ccall jl_safe_printf (" $kind \n " :: Cstring ):: Cvoid
8092 numstates[kind] += 1
8193 (kind != AlgebraicDerivative) && push! (all_states, var)
8294 end
@@ -101,17 +113,17 @@ function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheK
101113
102114 # out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_alg
103115 nassgn = numstates[AssignedDiff]
104- ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic]
116+ ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic] + numstates[AlgebraicDerivative]
105117 out_du_mm = @insert_node_here oc_compact line view (du, 1 : nassgn):: VectorViewType
106118 out_eq = @insert_node_here oc_compact line view (du, (nassgn+ 1 ): ntotalstates):: VectorViewType
107119
108- (in_u_mm, in_u_unassgn, in_alg) = sciml_ode_split_u! (oc_compact, line, u, numstates)
120+ (in_u_mm, in_u_unassgn, in_alg, in_alg_derv ) = sciml_ode_split_u! (oc_compact, line, u, numstates)
109121
110122 # Call DAECompiler-generated RHS with internal ABI
111123 oc_sicm = @insert_node_here oc_compact line getfield (self, 1 ):: Tuple
112124
113125 # N.B: The ordering of arguments should match the ordering in the StateKind enum
114- @insert_node_here oc_compact line (:invoke )(odef_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_alg, out_du_mm, out_eq, t):: Nothing
126+ @insert_node_here oc_compact line (:invoke )(odef_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_alg, in_alg_derv, out_du_mm, out_eq, t):: Nothing
115127
116128 # Return
117129 @insert_node_here oc_compact line (return ):: Union{}
@@ -129,14 +141,14 @@ function ode_factory_gen(result::DAEIPOResult, ci::CodeInstance, key::TornCacheK
129141
130142 new_oc = @insert_node_here compact line (:new_opaque_closure )(argt, Union{}, Nothing, true , oc_source_method, sicm):: Core.OpaqueClosure true
131143
132- if init_key != = nothing
133- initf = init_uncompress_gen! (compact, result, ci, init_key, key, world)
134- odef = @insert_node_here compact line make_odefunction (new_oc, initf):: ODEFunction true
135- else
136- odef = @insert_node_here compact line make_odefunction (new_oc):: ODEFunction true
137- end
144+ nd = numstates[AssignedDiff] + numstates[UnassignedDiff]
145+ na = numstates[Algebraic] + numstates[AlgebraicDerivative]
146+ mass_matrix = na == 0 ? GlobalRef (LinearAlgebra, :I ) : @insert_node_here compact line generate_ode_mass_matrix (nd, na):: Matrix{Float64}
147+ initf = init_key != = nothing ? init_uncompress_gen! (compact, result, ci, init_key, key, world) : nothing
148+ odef = @insert_node_here compact line make_odefunction (new_oc, mass_matrix, initf):: ODEFunction true
138149
139- @insert_node_here compact line (return odef):: Core.OpaqueClosure true
150+ odef_and_n = @insert_node_here compact line tuple (odef, nd + na):: Tuple true
151+ @insert_node_here compact line (return odef_and_n):: Core.OpaqueClosure true
140152
141153 ir_factory = Compiler. finish (compact)
142154
0 commit comments