1+ function tearing_visit_custom! (ir:: IRCode , ssa:: Union{SSAValue,Argument} , order, recurse)
2+ if isa (ssa, Argument)
3+ return false
4+ end
5+
6+ stmt = ir[ssa][:inst ]
7+ if is_known_invoke_or_call (stmt, variable, ir)
8+ return true
9+ elseif is_known_invoke (stmt, equation, ir)
10+ return true
11+ elseif is_known_invoke (stmt, sim_time, ir)
12+ return true
13+ elseif is_equation_call (stmt, ir)
14+ recurse (_eq_function_arg (stmt))
15+ recurse (_eq_val_arg (stmt))
16+ return true
17+ elseif is_known_invoke (stmt, ddt, ir)
18+ recurse (stmt. args[end ], order+ 1 )
19+ return true
20+ end
21+
22+ if isa (stmt, PhiNode)
23+ # Don't run our custom transform for PhiNodes - we don't have a place
24+ # to put the call and the regular recursion will handle it fine.
25+ return false
26+ end
27+
28+ typ = ir[ssa][:type ]
29+ has_simple_incidence_info (typ) || return false
30+
31+ # we have custom handling for things without any dependency on time nor state
32+ return ! has_dependence (typ)
33+ end
34+
35+ function is_diffed_equation_call_invoke_or_call (@nospecialize (stmt), ir:: IRCode )
36+ (isexpr (stmt, :invoke ) || isexpr (stmt, :call )) || return false
37+ callee = _eq_function_arg (stmt)
38+ isa (callee, SSAValue) || return false
39+ bundlecall = ir[callee][:inst ]
40+ isexpr (bundlecall, :call ) || return false
41+ bt = bundlecall. args[1 ]
42+ isa (bt, Type) || return false
43+ bt <: Diffractor.TaylorBundle || return false
44+ ft = argextype (bundlecall. args[2 ], ir)
45+ return widenconst (ft) === equation
46+ end
47+
48+ function index_lowering_ad! (state:: TransformationState , key:: TornCacheKey )
49+ (; result, structure) = state
50+ (; var_to_diff, eq_to_diff, graph, solvable_graph) = structure
51+
52+ ir = state. result. ir
53+
54+ # Figure out which equations we need to differentiate
55+ # TODO : Should have some nicer interface in MTK
56+ diff_eqs = Pair{Int, Int}[]
57+ for i = 1 : length (eq_to_diff)
58+ # If this is a linear equation, we cannot differentiate it, because
59+ # alias elimination changed the equation on us, but didn't update the
60+ # IR. We codegen it directly below.
61+ if invview (eq_to_diff)[i] === nothing && eq_to_diff[i] != = nothing && ! isempty (𝑠neighbors (graph, eq_to_diff[i]))
62+ level = 1
63+ diff = eq_to_diff[i]
64+ is_fully_state_linear (result. total_incidence[i], key. param_vars) && continue
65+ while (diff = eq_to_diff[diff]) != = nothing
66+ level += 1
67+ end
68+ push! (diff_eqs, i => level)
69+ end
70+ end
71+
72+ # Mark all non-trivial `ddt()` statements as ones that we should differentiate
73+ diff_ssas = Pair{SSAValue,Int}[]
74+ for i = 1 : length (ir. stmts)
75+ if is_known_invoke (ir. stmts[i][:stmt ], ddt, ir) && ! is_const_plus_state_linear (argextype (ir. stmts[i][:stmt ]. args[end ], ir), key. param_vars)
76+ push! (diff_ssas, SSAValue (i) => 0 )
77+ end
78+ end
79+
80+ if isempty (diff_ssas) && isempty (diff_eqs)
81+ return copy (ir)
82+ end
83+
84+ compact = IncrementalCompact (copy (ir))
85+ # TODO : This could all be combined with the below into a single pass
86+ (eqs, vars) = find_eqs_vars (state. structure. graph, compact)
87+ ir = Compiler. finish (compact)
88+
89+ for (eq, level) in diff_eqs
90+ for ssa in eqs[eq][2 ]
91+ push! (diff_ssas, ssa => level)
92+ end
93+ end
94+
95+ append! (eqs, (SSAValue (0 )=> SSAValue[] for _ in 1 : (length (eq_to_diff)- length (eqs))))
96+ append! (vars, fill (SSAValue (0 ), length (var_to_diff)- length (vars)))
97+ domtree = Compiler. construct_domtree (ir. cfg. blocks)
98+
99+ function diff_one! (ir, ssa, dvar)
100+ if dvar === nothing
101+ # dvar can be `nothing` if we are differentiating a variable that doesn't actually appear
102+ # in the matched system structure's incidence analysis for the equation currently being differentiated.
103+ # This can occur because Diffractor's types bundle both the primal and the tangent derivatives
104+ # in a single type, causing differentiation of all listed variables to hit this function.
105+ # We emit here a `_DIFF_UNUSED` value that we expect to never be used and DCE'd later on in the pipeline.
106+ return insert_node! (ir, ssa, NewInstruction (GlobalRef (DAECompiler. Intrinsics, :_DIFF_UNUSED ), Incidence (Float64), Int32 (1 )))
107+ end
108+ if vars[dvar] == SSAValue (0 )
109+ vars[dvar] = insert_node! (ir, ssa, NewInstruction (Expr (:invoke , nothing , variable), Incidence (dvar)))
110+ elseif ! dominates_ssa (ir, domtree, vars[dvar], ssa; dominates_after= true )
111+ varssa = vars[dvar]
112+ inst = ir[varssa]
113+ vars[dvar] = insert_node! (ir, ssa, NewInstruction (inst))
114+ ir[varssa][:inst ] = vars[dvar]
115+ end
116+ return vars[dvar]
117+ end
118+
119+ function diff_variable! (ir, ssa, stmt, order)
120+ inst = ir[ssa]
121+ var = idnum (ir[ssa][:type ])
122+ primal = insert_node! (ir, ssa, NewInstruction (inst))
123+ vars[var] = primal
124+ diffs = SSAValue[]
125+ for i = 1 : order
126+ var != = nothing && (var = var_to_diff[var])
127+ push! (diffs, diff_one! (ir, ssa, var))
128+ end
129+ duals = insert_node! (ir, ssa, NewInstruction (
130+ Expr (:call , tuple, diffs... ), Any
131+ ))
132+ replace_call! (ir, ssa, Expr (:call , Diffractor. TaylorBundle{order}, primal, duals))
133+ end
134+
135+ function transform! (ir, ssa, order, maparg)
136+ if isa (ssa, Argument)
137+ # at start of function define a SSA holding the initially accumulated derivative of each argument, (i.e. 0)
138+ return insert_node! (ir, SSAValue (1 ), NewInstruction (Expr (:call , Diffractor. zero_bundle {order} (), ssa), Any))
139+ end
140+ inst = ir[ssa]
141+ stmt = inst[:inst ]
142+ while isa (stmt, SSAValue)
143+ # It's possible an earlier call to transform! moved this call, so follow references.
144+ stmt = ir[stmt][:inst ]
145+ end
146+ if is_known_invoke (stmt, variable, ir)
147+ diff_variable! (ir, ssa, stmt, order)
148+ return nothing
149+ elseif is_known_invoke (stmt, equation, ir)
150+ eq = inst[:type ]. id
151+ primal = insert_node! (ir, ssa, NewInstruction (inst))
152+ eqs[eq] = primal=> eqs[eq][2 ]
153+ duals = SSAValue[]
154+ for _ = 1 : order
155+ deq = eq_to_diff[eq]
156+ # If `deq` is nothing, that means we're asking for a derivative of an equation
157+ # that does not exist. This is possible if we, for instance, have a tuple of
158+ # equation-related values that does not get SROA'ed, and is then differentiated
159+ # by Diffractor due to _one_ of the equations being differentiated. But that
160+ # results in this loop asking for derivatives of the _other_ equations that
161+ # don't exist. To handle this, we insert a bogus equation node, similar in
162+ # spirit to the `_DIFF_UNUSED` value.
163+ if deq === nothing
164+ diff = insert_node! (ir, ssa, NewInstruction (GlobalRef (DAECompiler. Intrinsics, :_EQ_UNUSED ), equation))
165+ else
166+ diff = insert_node! (ir, ssa, NewInstruction (inst))
167+ diffinst = ir[diff]
168+ diffinst[:type ] = Eq (deq)
169+ eqs[deq] = diff=> eqs[deq][2 ]
170+ eq = deq
171+ end
172+ push! (duals, diff)
173+ end
174+ dtup = insert_node! (ir, ssa, NewInstruction (
175+ Expr (:call , tuple, duals... ), Any
176+ ))
177+ # N.B.: No replace_call!, because we rely on the type of this call.
178+ inst[:inst ] = Expr (:call , Diffractor. TaylorBundle{order}, primal, dtup)
179+ inst[:info ] = Compiler. NoCallInfo ()
180+ return nothing
181+ elseif is_known_invoke (stmt, sim_time, ir)
182+ time = insert_node! (ir, ssa, NewInstruction (inst))
183+ replace_call! (ir, ssa, Expr (:call , Diffractor.∂xⁿ {order} (), time))
184+ return nothing
185+ elseif is_diffed_equation_call_invoke_or_call (stmt, ir)
186+ eq = idnum (argextype (_eq_function_arg (stmt), ir))
187+ bundle = _eq_val_arg (stmt)
188+ # Rewrite the equation (we could extract it from the bundle, but we already know where it is)
189+ # N.B.: We don't need replace_call! here, because we're not changing the call target,
190+ # we're just rearranging the SSA.
191+ inst[:inst ] = Expr (
192+ :call ,
193+ eqs[eq][1 ],
194+ insert_node! (ir, ssa, NewInstruction (Expr (:call , getfield, bundle, 1 ), Any)), # primal
195+ )
196+ # Pull out the equation from the primal, so we can null it out below
197+ new_primal = insert_node! (ir, ssa, NewInstruction (inst))
198+ replace! (eqs[eq][2 ], ssa=> new_primal)
199+ for i = 1 : order
200+ val = insert_node! (ir, ssa, NewInstruction (Expr (:call , getindex, bundle, Diffractor. TaylorTangentIndex (i)), Any))
201+ push! (
202+ eqs[eq_to_diff[eq]][2 ],
203+ insert_node! (ir, ssa, NewInstruction (Expr (:call , eqs[eq_to_diff[eq]][1 ], val), Any))
204+ )
205+ eq = eq_to_diff[eq]
206+ end
207+ # equation! also returns nothing, but it's possible for the value
208+ # to be used (e.g. by a return, so conform to the interface)
209+ dnullout_inst! (inst, order)
210+ elseif is_known_invoke (stmt, ddt, ir)
211+ arg = maparg (stmt. args[end ], ssa, order+ 1 )
212+ if order == 0
213+ replace_call! (ir, ssa, Expr (:call , Diffractor. partial, arg, 1 ))
214+ else
215+ replace_call! (ir, ssa, Expr (:call , diff_bundle, arg))
216+ end
217+ return nothing
218+ else
219+ # must be something with no dependency
220+ @assert ! has_dependence (inst[:type ])
221+ urs = userefs (stmt)
222+ for ur in urs
223+ ur[] = maparg (ur[], ssa, 0 )
224+ end
225+ inst[:inst ] = urs[]
226+ primal = insert_node! (ir, ssa, NewInstruction (inst))
227+ replace_call! (ir, ssa, Expr (:call , Diffractor. zero_bundle {order} (), primal))
228+ return nothing
229+ end
230+ end
231+ Diffractor. forward_diff_no_inf! (ir, diff_ssas; visit_custom! = tearing_visit_custom!, transform!, eras_mode= true )
232+
233+ # Rename state
234+ compact = IncrementalCompact (ir)
235+ (eqs, vars) = find_eqs_vars (state. structure. graph, compact)
236+ # Some variables may look dead, but are used in linear equations
237+ # don't dce them just yet - we'll dce them below
238+ Compiler. non_dce_finish! (compact)
239+ ir = Compiler. complete (compact)
240+
241+ # Derivatives can appear out of "thin air" due to implicit dependencies
242+ # (i.e. an equation that depends on 1 also depends on ddt(1)), or due to
243+ # imprecision introduced by the AD transform (causing a primal to
244+ # spuriously be carried along in the Incidence with its derivative).
245+ #
246+ # Allow this by verifying there is an element in `g` whose k-derivative
247+ # is `var` (k ∈ ℤ).
248+ function in_any_derivative (var, g)
249+ while var_to_diff[var] != = nothing
250+ var = var_to_diff[var] # Normalize to highest-derivative
251+ end
252+ while true
253+ var in g && return true
254+ invview (var_to_diff)[var] === nothing && return false
255+ var = invview (var_to_diff)[var]
256+ end
257+ end
258+
259+ # Update solvable graph
260+ #=
261+ for (eq, (_, eqssas)) in enumerate(eqs)
262+ is_fully_state_linear(state.total_incidence[eq], key.param_vars) && continue
263+ old_graph = empty_eq_list!(graph, eq)
264+ old_solvable_graph = empty_eq_list!(solvable_graph, eq)
265+ for eqssa in eqssas
266+ if ir[eqssa][:inst] === nothing
267+ # Could have been in a dead branch and deleted - allow that for now.
268+ continue
269+ end
270+ eqssaval = _eq_val_arg(ir[eqssa][:inst])
271+ inc = ir[eqssaval][:type]
272+ if !isa(inc, Incidence)
273+ throw(UnsupportedIRException("Expected incidence analysis to produce result for $eqssaval, got $inc", ir))
274+ end
275+ for (v, coeff) in zip(rowvals(inc.row), nonzeros(inc.row))
276+ v == 1 && continue
277+ @assert in_any_derivative(v-1, old_graph)
278+ @assert !has_edge(graph, BipartiteEdge(eq, v-1))
279+ add_edge!(graph, eq, v-1)
280+ if coeff !== nonlinear
281+ add_edge!(solvable_graph, eq, v-1)
282+ else
283+ # TODO : solvable should generally not become unsolvable but in some cases
284+ # our AD transform widens Incidence propagation in a way that artificially
285+ # makes tearing's life harder (see downstream BSIM-CMG test)
286+ # @assert !(v-1 in old_solvable_graph)
287+ if v-1 in old_solvable_graph
288+ @debug "Variable $(v-1) in Eq. $(eq) went from solvable -> unsolvable after AD transform"
289+ end
290+ end
291+ end
292+ end
293+ end
294+ =#
295+
296+ return ir
297+ end
298+
299+ function empty_eq_list! (graph:: BipartiteGraph , eq)
300+ vs = copy (𝑠neighbors (graph, eq))
301+ foreach (vs) do v
302+ rem_edge! (graph, eq, v)
303+ end
304+ return vs
305+ end
0 commit comments