@@ -39,101 +39,51 @@ function make_daefunction(f, initf)
3939 DAEFunction (f; initialization_data = SciMLBase. OverrideInitData (NonlinearProblem ((args... )-> nothing , nothing , nothing ), nothing , initf, nothing , nothing , Val {false} ()))
4040end
4141
42- """
43- dae_factory_gen(ci, key)
44-
45- Generate the `factory` function for CodeInstance `ci`, returning a DAEFunction.
46- The resulting function is roughly:
47-
48- ```
49- function factory(settings, f)
50- # Run all parts of `f` that do not depend on state
51- state_invariant_pieces = f_state_invariant()
52- f! = %new_opaque_closure(f_rhs, state_invariant_pieces)
53- DAEFunction(f!), differential_vars
42+ function continuous_variables (state:: TransformationState )
43+ filter (var -> varkind (state, var) == Intrinsics. Continuous, 1 : length (state. result. var_to_diff))
5444end
55- ```
56-
57- """
58- function dae_factory_gen (state:: TransformationState , ci:: CodeInstance , key:: TornCacheKey , world:: UInt , settings:: Settings , init_key:: Union{TornCacheKey, Nothing} )
59- result = state. result
60- torn_ci = find_matching_ci (ci-> isa (ci. owner, TornIRSpec) && ci. owner. key == key, ci. def, world)
61- torn_ir = torn_ci. inferred
62-
63- (;ir_sicm) = torn_ir
6445
65- ir_factory = copy (ci. inferred. ir)
66- pushfirst! (ir_factory. argtypes, Settings)
67- pushfirst! (ir_factory. argtypes, typeof (factory))
68- compact = IncrementalCompact (ir_factory)
69-
70- local line
71- if ir_sicm != = nothing
72- sicm_ci = find_matching_ci (ci-> isa (ci. owner, SICMSpec) && ci. owner. key == key, ci. def, world)
73- @assert sicm_ci != = nothing
74-
75- line = result. ir[SSAValue (1 )][:line ]
76- param_list = flatten_parameter! (Compiler. fallback_lattice, compact, ci. inferred. ir. argtypes[1 : end ], argn-> Argument (2 + argn), line, settings)
77- sicm = @insert_instruction_here compact line settings invoke (param_list, sicm_ci):: Tuple
78- else
79- sicm = ()
80- end
81-
82- argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase. NullParameters, Float64}
83-
84- daef_ci = rhs_finish! (state, ci, key, world, settings, 1 )
46+ const SCIML_ABI = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase. NullParameters, Float64}
8547
86- # Create a small opaque closure to adapt from SciML ABI to our own internal
87- # ABI
48+ function sciml_to_internal_abi! (ir :: IRCode , state :: TransformationState , internal_ci :: CodeInstance , key :: TornCacheKey , var_eq_matching, settings :: Settings )
49+ (; result, structure) = state
8850
8951 numstates = zeros (Int, Int (LastEquationStateKind))
90-
91- all_states = Int[]
92- for var = 1 : length (result. var_to_diff)
93- varkind (state, var) == Intrinsics. Continuous || continue
52+ for var in continuous_variables (state)
9453 kind = classify_var (result. var_to_diff, key, var)
9554 kind == nothing && continue
9655 numstates[kind] += 1
97- (kind != AlgebraicDerivative) && push! (all_states, var)
9856 end
9957
100- ir_oc = copy (ci. inferred. ir)
101- empty! (ir_oc. argtypes)
102- push! (ir_oc. argtypes, Tuple)
103- push! (ir_oc. argtypes, Vector{Float64})
104- push! (ir_oc. argtypes, Vector{Float64})
105- push! (ir_oc. argtypes, Vector{Float64})
106- push! (ir_oc. argtypes, SciMLBase. NullParameters)
107- push! (ir_oc. argtypes, Float64)
58+ empty! (ir. argtypes)
59+ push! (ir. argtypes, Tuple) # opaque closure captures
60+ append! (ir. argtypes, fieldtypes (SCIML_ABI))
10861
109- oc_compact = IncrementalCompact (ir_oc )
62+ compact = IncrementalCompact (ir )
11063
11164 # Zero the output
112- line = ir_oc [SSAValue (1 )][:line ]
113- @insert_instruction_here oc_compact line settings zero! (Argument (2 )):: VectorViewType
65+ line = ir [SSAValue (1 )][:line ]
66+ @insert_instruction_here compact line settings zero! (Argument (2 )):: VectorViewType
11467
11568 # out_du_mm, out_eq, in_u_mm, in_u_unassgn, in_du_unassgn, in_alg
11669 nassgn = numstates[AssignedDiff]
11770 ntotalstates = numstates[AssignedDiff] + numstates[UnassignedDiff] + numstates[Algebraic]
118- out_du_mm = @insert_instruction_here oc_compact line settings view (Argument (2 ), 1 : nassgn):: VectorViewType
119- out_eq = @insert_instruction_here oc_compact line settings view (Argument (2 ), (nassgn+ 1 ): ntotalstates):: VectorViewType
71+ out_du_mm = @insert_instruction_here compact line settings view (Argument (2 ), 1 : nassgn):: VectorViewType
72+ out_eq = @insert_instruction_here compact line settings view (Argument (2 ), (nassgn+ 1 ): ntotalstates):: VectorViewType
12073
121- (in_du_assgn, in_du_unassgn) = sciml_dae_split_du! (oc_compact , line, settings, Argument (3 ), numstates)
122- (in_u_mm, in_u_unassgn, in_alg) = sciml_dae_split_u! (oc_compact , line, settings, Argument (4 ), numstates)
74+ (in_du_assgn, in_du_unassgn) = sciml_dae_split_du! (compact , line, settings, Argument (3 ), numstates)
75+ (in_u_mm, in_u_unassgn, in_alg) = sciml_dae_split_u! (compact , line, settings, Argument (4 ), numstates)
12376
12477 # Call DAECompiler-generated RHS with internal ABI
125- oc_sicm = @insert_instruction_here oc_compact line settings getfield (Argument (1 ), 1 ):: Core.OpaqueClosure
78+ oc_sicm = @insert_instruction_here compact line settings getfield (Argument (1 ), 1 ):: Core.OpaqueClosure
12679
12780 # N.B: The ordering of arguments should match the ordering in the StateKind enum
128- @insert_instruction_here oc_compact line settings (:invoke )(daef_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, out_du_mm, out_eq, Argument (6 )):: Nothing
129-
130- # TODO : We should not have to recompute this here
131- var_eq_matching = matching_for_key (state, key)
132- (slot_assignments, var_assignment, eq_assignment) = assign_slots (state, key, var_eq_matching)
81+ @insert_instruction_here compact line settings (:invoke )(internal_ci, oc_sicm, (), in_u_mm, in_u_unassgn, in_du_unassgn, in_alg, out_du_mm, out_eq, Argument (6 )):: Nothing
13382
13483 # Manually apply mass matrix and implicit equations between selected states
135- for v = 1 : ndsts (state. structure. graph)
136- vdiff = state. structure. var_to_diff[v]
84+ (_, var_assignment, _) = assign_slots (state, key, var_eq_matching)
85+ for v = 1 : ndsts (structure. graph)
86+ vdiff = structure. var_to_diff[v]
13787 vdiff === nothing && continue
13888
13989 if var_eq_matching[v] != = SelectedState () || var_eq_matching[vdiff] != = SelectedState ()
@@ -146,22 +96,81 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
14696 @assert kind == AssignedDiff
14797 @assert dkind in (AssignedDiff, UnassignedDiff)
14898
149- v_val = @insert_instruction_here oc_compact line settings getindex (dkind == AssignedDiff ? in_u_mm : in_u_unassgn, dslot):: Any
150- @insert_instruction_here oc_compact line settings setindex! (out_du_mm, v_val, slot):: Any
99+ v_val = @insert_instruction_here compact line settings getindex (dkind == AssignedDiff ? in_u_mm : in_u_unassgn, dslot):: Any
100+ @insert_instruction_here compact line settings setindex! (out_du_mm, v_val, slot):: Any
151101 end
152102
153- bc = @insert_instruction_here oc_compact line settings Base. Broadcast. broadcasted (- , out_du_mm, in_du_assgn):: Any
154- @insert_instruction_here oc_compact line settings Base. Broadcast. materialize! (out_du_mm, bc):: Nothing
103+ bc = @insert_instruction_here compact line settings Base. Broadcast. broadcasted (- , out_du_mm, in_du_assgn):: Any
104+ @insert_instruction_here compact line settings Base. Broadcast. materialize! (out_du_mm, bc):: Nothing
155105
156106 # Return
157- @insert_instruction_here oc_compact line settings (return nothing ):: Union{}
107+ @insert_instruction_here compact line settings (return nothing ):: Union{}
158108
159- ir_oc = Compiler. finish (oc_compact)
160- maybe_rewrite_debuginfo! (ir_oc, settings)
161- resize! (ir_oc. cfg. blocks, 1 )
162- empty! (ir_oc. cfg. blocks[1 ]. succs)
163- Compiler. verify_ir (ir_oc)
164- oc = Core. OpaqueClosure (ir_oc)
109+ ir = Compiler. finish (compact)
110+ maybe_rewrite_debuginfo! (ir, settings)
111+ resize! (ir. cfg. blocks, 1 )
112+ empty! (ir. cfg. blocks[1 ]. succs)
113+ Compiler. verify_ir (ir)
114+
115+ @async @eval Main begin
116+ interface_ir = $ ir
117+ end
118+
119+ return Core. OpaqueClosure (ir; slotnames = [:captures , :out , :du , :u , :p , :t ])
120+ end
121+
122+ """
123+ dae_factory_gen(ci, key)
124+
125+ Generate the `factory` function for CodeInstance `ci`, returning a DAEFunction.
126+ The resulting function is roughly:
127+
128+ ```
129+ function factory(settings, f)
130+ # Run all parts of `f` that do not depend on state
131+ state_invariant_pieces = f_state_invariant()
132+ f! = %new_opaque_closure(f_rhs, state_invariant_pieces)
133+ DAEFunction(f!), differential_vars
134+ end
135+ ```
136+
137+ """
138+ function dae_factory_gen (state:: TransformationState , ci:: CodeInstance , key:: TornCacheKey , world:: UInt , settings:: Settings , init_key:: Union{TornCacheKey, Nothing} )
139+ result = state. result
140+ # TODO : We should not have to recompute this here
141+
142+ ir_factory = copy (ci. inferred. ir)
143+ pushfirst! (ir_factory. argtypes, Settings)
144+ pushfirst! (ir_factory. argtypes, typeof (factory))
145+ compact = IncrementalCompact (ir_factory)
146+
147+ # Create a small opaque closure to adapt from SciML ABI to our own internal ABI
148+ argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase. NullParameters, Float64}
149+ sicm = ()
150+ if settings. skip_optimizations
151+ daef_ci = rhs_finish_noopt! (state, ci, key, world, settings, 1 )
152+ oc = sciml_to_internal_abi_noopt! (copy (ci. inferred. ir), state, daef_ci, settings)
153+ else
154+ var_eq_matching = matching_for_key (state, key)
155+
156+ torn_ci = find_matching_ci (ci-> isa (ci. owner, TornIRSpec) && ci. owner. key == key, ci. def, world)
157+ torn_ir = torn_ci. inferred
158+
159+ (; ir_sicm) = torn_ir
160+
161+ local line
162+ if ir_sicm != = nothing
163+ sicm_ci = find_matching_ci (ci-> isa (ci. owner, SICMSpec) && ci. owner. key == key, ci. def, world)
164+ @assert sicm_ci != = nothing
165+
166+ line = result. ir[SSAValue (1 )][:line ]
167+ param_list = flatten_parameter! (Compiler. fallback_lattice, compact, ci. inferred. ir. argtypes[1 : end ], argn-> Argument (2 + argn), line, settings)
168+ sicm = @insert_instruction_here compact line settings invoke (param_list, sicm_ci):: Tuple
169+ end
170+
171+ daef_ci = rhs_finish! (state, ci, key, world, settings, 1 )
172+ oc = sciml_to_internal_abi! (copy (ci. inferred. ir), state, daef_ci, key, var_eq_matching, settings)
173+ end
165174
166175 line = result. ir[SSAValue (1 )][:line ]
167176
@@ -173,6 +182,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
173182
174183 new_oc = @insert_instruction_here compact line settings (:new_opaque_closure )(argt, Union{}, Nothing, true , oc_source_method, sicm):: Core.OpaqueClosure true
175184
185+ all_states = filter (var -> classify_var (result, key, var) != AlgebraicDerivative, continuous_variables (state))
176186 differential_states = Bool[v in key. diff_states for v in all_states]
177187
178188 if init_key != = nothing
@@ -192,6 +202,6 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
192202 empty! (ir_factory. cfg. blocks[1 ]. succs)
193203 Compiler. verify_ir (ir_factory)
194204
195- slotnames = [[ :factory , :settings ]; Symbol .( :arg , 1 : ( length (ir_factory . argtypes) - 2 )) ]
205+ slotnames = [:factory , :settings , :f ]
196206 return ir_factory, slotnames
197207end
0 commit comments