@@ -6,67 +6,90 @@ function rhs_finish_noopt!(
66 key:: UnoptimizedKey ,
77 world:: UInt ,
88 settings:: Settings ,
9- indexT= Int)
9+ equation_to_residual_mapping = 1 : length (state. structure. eq_to_diff),
10+ variable_to_state_mapping = map_variables_to_states (state))
1011
1112 (; result, structure) = state
12- result_ci = find_matching_ci (ci -> ci. inferred === key, ci. def, world)
13+ result_ci = find_matching_ci (ci -> ci. owner === key, ci. def, world)
1314 if result_ci != = nothing
1415 return result_ci
1516 end
1617
1718 ir = copy (result. ir)
18- slotnames = [:captures , :vars , :out , :du , :u , :out_indices , :du_indices , :u_indices , :t ]
19- argtypes = [Tuple, Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Float64}, VectorIntViewType, VectorIntViewType, VectorIntViewType, Float64]
19+ # TODO : use original function arguments too
20+ slotnames = [:captures , :out , :du , :u , :residuals , :states , :t ]
21+ argtypes = [Tuple, Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Int}, Vector{Int}, Float64]
2022 append! (empty! (ir. argtypes), argtypes)
21- captures, vars, out, du, u, out_indices, du_indices, u_indices, t = Argument .(eachindex (slotnames))
22- # TODO : use `out_indices`, `du_indices`, `u_indices`
23+ captures, out, du, u, residuals, states, t = Argument .(eachindex (slotnames))
2324 @assert length (slotnames) == length (ir. argtypes)
2425
2526 equations = Pair{SSAValue, Eq}[]
27+ callee_to_caller_eq_map = invert_eq_callee_mapping (result. eq_callee_mapping)
2628 compact = IncrementalCompact (ir)
2729
28- for ((_, i), _) in compact
30+ # @sshow equation_to_residual_mapping variable_to_state_mapping
31+
32+ for ((old, i), _) in compact
2933 ssaidx = SSAValue (i)
3034 inst = compact[ssaidx]
3135 stmt = inst[:stmt ]
3236 type = inst[:type ]
3337 line = inst[:line ]
3438
35- if is_known_invoke (stmt, Intrinsics. equation, compact)
39+ if i == 1
40+ # @insert_instruction_here(compact, nothing, settings, println("Residuals: ", residuals)::Any)
41+ # @insert_instruction_here(compact, nothing, settings, println("States: ", states)::Any)
42+ replace_uses! (compact, (old, inst) => SSAValue (3 ))
43+ end
44+
45+ if is_known_invoke_or_call (stmt, Intrinsics. variable, compact)
46+ var = idnum (type)
47+ index = @insert_instruction_here (compact, line, settings, getindex (states, var):: Int )
48+ value = @insert_instruction_here (compact, line, settings, getindex (u, index):: Float64 )
49+ replace_uses! (compact, (old, inst) => value)
50+ # @insert_instruction_here(compact, line, settings, println("Variable (", var, "): ", value)::Float64)
51+ elseif is_known_invoke (stmt, Intrinsics. ddt, compact)
52+ var = idnum (type)
53+ index = @insert_instruction_here (compact, line, settings, getindex (states, var):: Int )
54+ value = @insert_instruction_here (compact, line, settings, getindex (du, index):: Float64 )
55+ replace_uses! (compact, (old, inst) => value)
56+ # @insert_instruction_here(compact, line, settings, println("Variable derivative (", var, " := ", invview(structure.var_to_diff)[var], "′): ", value)::Any)
57+ elseif is_known_invoke (stmt, Intrinsics. equation, compact)
3658 push! (equations, ssaidx => type:: Eq )
3759 inst[:stmt ] = nothing
38- elseif is_known_invoke (stmt, Intrinsics. ddt, compact)
39- var = invview (structure. var_to_diff)[idnum (type)]
40- getdu = Expr (:call , getindex, du, var)
41- replace_call! (compact, ssaidx, getdu, settings, @__SOURCE__ )
42- inst[:type ] = Float64
4360 elseif is_equation_call (stmt, compact)
4461 callee, value = stmt. args[2 ], stmt. args[3 ]
4562 i = findfirst (x -> first (x) == callee, equations):: Int
4663 eq = last (equations[i])
47- call = Expr (:call , setindex!, out, value, eq. id)
48- replace_call! (compact, ssaidx, call, settings, @__SOURCE__ )
49- elseif is_known_invoke_or_call (stmt, Intrinsics. variable, compact)
50- var = idnum (type)
51- call = Expr (:call , getindex, u, var)
52- replace_call! (compact, ssaidx, call, settings, @__SOURCE__ )
53- inst[:type ] = Float64
64+ index = @insert_instruction_here (compact, line, settings, getindex (residuals, eq. id):: Int )
65+ ret = @insert_instruction_here (compact, line, settings, setindex! (out, value, index):: Any )
66+ replace_uses! (compact, (old, inst) => ret)
67+ # @insert_instruction_here(compact, line, settings, println("Residuals (index = ", index, ", value = ", value, "): ", residuals)::Any)
5468 elseif is_known_invoke_or_call (stmt, Intrinsics. sim_time, compact)
5569 inst[:stmt ] = t
5670 elseif is_known_invoke_or_call (stmt, Intrinsics. epsilon, compact)
5771 inst[:stmt ] = 0.0
5872 elseif isexpr (stmt, :invoke )
59- @sshow stmt
73+ info = inst[ :info ] :: MappingInfo
6074 callee_ci, callee_f = stmt. args[1 ]:: CodeInstance , stmt. args[2 ]
6175 callee_result = structural_analysis! (callee_ci, world, settings)
6276 callee_structure = make_structure_from_ipo (callee_result)
6377 callee_state = TransformationState (callee_result, callee_structure)
64- callee_daef_ci = rhs_finish_noopt! (callee_state, callee_ci, UnoptimizedKey (), world, settings)
65- callee_captures = ()
66- # TODO : compute indices into `u`/`du`/`out`
67- empty! (stmt. args)
68- push! (stmt. args, callee_daef_ci, callee_captures, vars,
69- out, du, u, out_indices, du_indices, u_indices, t)
78+
79+ callee_residuals = equation_to_residual_mapping[callee_to_caller_eq_map[StructuralSSARef (old)]]
80+ caller_variables = idnum .(info. mapping. var_coeffs)
81+ callee_states = variable_to_state_mapping[caller_variables]
82+
83+ callee_daef_ci = rhs_finish_noopt! (callee_state, callee_ci, UnoptimizedKey (), world, settings, callee_residuals, callee_states)
84+ call = @insert_instruction_here (compact, line, settings, (:invoke )(callee_daef_ci, callee_f,
85+ out,
86+ du,
87+ u,
88+ @insert_instruction_here (compact, line, settings, getindex (Int, callee_residuals... ):: Vector{Int} ),
89+ @insert_instruction_here (compact, line, settings, getindex (Int, callee_states... ):: Vector{Int} ),
90+ t):: type )
91+ # TODO : add `stmt.args[3:end]`
92+ replace_uses! (compact, (old, inst) => call)
7093 end
7194 type = inst[:type ]
7295 if isa (type, Incidence) || isa (type, Eq)
@@ -75,9 +98,28 @@ function rhs_finish_noopt!(
7598 end
7699
77100 daef_ci = rhs_finish_ir! (Compiler. finish (compact), ci, settings, key, slotnames)
101+ # @sshow daef_ci.inferred
78102 return daef_ci
79103end
80104
105+ function map_variables_to_states (state:: TransformationState )
106+ (; structure) = state
107+ diff_to_var = invview (structure. var_to_diff)
108+ states = Int[]
109+ prev_state = 0
110+ for var in continuous_variables (state)
111+ ref = is_differential_variable (structure, var) ? diff_to_var[var] : var
112+ state = @something (get (states, ref, nothing ), prev_state += 1 )
113+ push! (states, state)
114+ end
115+ return states
116+ end
117+
118+ function replace_uses! (compact, ((old, inst), new))
119+ inst[:stmt ] = nothing
120+ compact. ssa_rename[old] = new
121+ end
122+
81123function sciml_to_internal_abi_noopt! (ir:: IRCode , state:: TransformationState , internal_ci:: CodeInstance , settings:: Settings )
82124 slotnames = [:captures , :out , :du , :u , :p , :t ]
83125 captures, out, du, u, p, t = Argument .(eachindex (slotnames))
@@ -90,14 +132,11 @@ function sciml_to_internal_abi_noopt!(ir::IRCode, state::TransformationState, in
90132 line = ir[SSAValue (1 )][:line ]
91133
92134 internal_oc = @insert_instruction_here (compact, line, settings, getfield (captures, 1 ):: Core.OpaqueClosure )
93- # TODO : Compute proper indices.
94135 neqs = length (state. structure. eq_to_diff)
95- out_indices = @insert_instruction_here (compact, line, settings, view (out, 1 : neqs):: VectorIntViewType )
96- du_indices = @insert_instruction_here (compact, line, settings, view (du, 1 : neqs):: VectorIntViewType )
97- u_indices = @insert_instruction_here (compact, line, settings, view (u, 1 : neqs):: VectorIntViewType )
98- # TODO : Provide actual external variables.
99- vars = @insert_instruction_here (compact, line, settings, getindex (Float64):: Vector{Float64} )
100- @insert_instruction_here (compact, line, settings, (:invoke )(internal_ci, internal_oc, vars, out, du, u, out_indices, du_indices, u_indices, t):: Nothing )
136+ nvars = length (state. structure. var_to_diff)
137+ residuals = @insert_instruction_here (compact, line, settings, getindex (Int, 1 : neqs... ):: Vector{Int} )
138+ states = @insert_instruction_here (compact, line, settings, getindex (Int, map_variables_to_states (state)... ):: Vector{Int} )
139+ @insert_instruction_here (compact, line, settings, (:invoke )(internal_ci, internal_oc, out, du, u, residuals, states, t):: Nothing )
101140 @insert_instruction_here (compact, line, settings, (return nothing ):: Union{} )
102141
103142 ir = Compiler. finish (compact)
0 commit comments