Skip to content

Commit 34876d7

Browse files
committed
Merge branch 'main' of github.com:CedarEDA/DAECompiler.jl into ode-codegen
2 parents 6d596cc + 1394232 commit 34876d7

File tree

9 files changed

+378
-54
lines changed

9 files changed

+378
-54
lines changed

src/DAECompiler.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ module DAECompiler
3131
include("transform/codegen/init_factory.jl")
3232
include("transform/codegen/rhs.jl")
3333
include("transform/codegen/init_uncompress.jl")
34+
include("transform/autodiff/ad_common.jl")
35+
include("transform/autodiff/ad_runtime.jl")
36+
include("transform/autodiff/index_lowering.jl")
3437
include("interface.jl")
3538
include("problem_interface.jl")
3639
end

src/analysis/refiner.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,12 @@ function structural_inc_ddt(var_to_diff::DiffGraph, varclassification::Union{Vec
9696
if isa(base, Const)
9797
if isa(coeff, Float64)
9898
base = Const(base.val + coeff)
99+
# Do not set r[v_offset]; d/dt t = 1
99100
else
100-
base = widenconst(base)
101+
r[v_offset] = nonlinear
101102
end
103+
elseif !isa(coeff, Const)
104+
r[v_offset] = nonlinear
102105
end
103106
continue
104107
end

src/interface.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ function factory_gen(world::UInt, source::Method, @nospecialize(_gen), settings,
4141
tstate = TransformationState(result, structure, copy(result.total_incidence))
4242
(diff_key, init_key) = top_level_state_selection!(tstate)
4343

44-
# TODO: Index lowering AD here
45-
4644
if settings.mode in (DAE, DAENoInit, ODE, ODENoInit)
4745
tearing_schedule!(tstate, ci, diff_key, world)
4846
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Base.Meta
2+
using Base: quoted
3+
4+
"Set inst to be a differentiable nothing of given order"
5+
function dnullout_inst!(inst, order=1)
6+
inst[:inst] = quoted(Diffractor.DNEBundle{order}(nothing))
7+
inst[:type] = typeof(Diffractor.DNEBundle{order}(nothing))
8+
inst[:flag] = Compiler.IR_FLAG_EFFECT_FREE | Compiler.IR_FLAG_NOTHROW | Compiler.IR_FLAG_CONSISTENT
9+
return nothing
10+
end
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
function diff_bundle(bundle::Diffractor.UniformBundle{N, B, U}) where {N, B, U}
2+
return Diffractor.UniformBundle{N-1}(bundle.tangent.val, bundle.tangent)
3+
end
4+
5+
function diff_bundle(bundle::Diffractor.TaylorBundle{N}) where {N}
6+
return Diffractor.TaylorBundle{N-1}(bundle.tangent.coeffs[1], bundle.tangent.coeffs[2:end])
7+
end
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
function tearing_visit_custom!(ir::IRCode, ssa::Union{SSAValue,Argument}, order, recurse)
2+
if isa(ssa, Argument)
3+
return false
4+
end
5+
6+
stmt = ir[ssa][:inst]
7+
if is_known_invoke_or_call(stmt, variable, ir)
8+
return true
9+
elseif is_known_invoke(stmt, equation, ir)
10+
return true
11+
elseif is_known_invoke(stmt, sim_time, ir)
12+
return true
13+
elseif is_equation_call(stmt, ir)
14+
recurse(_eq_function_arg(stmt))
15+
recurse(_eq_val_arg(stmt))
16+
return true
17+
elseif is_known_invoke(stmt, ddt, ir)
18+
recurse(stmt.args[end], order+1)
19+
return true
20+
end
21+
22+
if isa(stmt, PhiNode)
23+
# Don't run our custom transform for PhiNodes - we don't have a place
24+
# to put the call and the regular recursion will handle it fine.
25+
return false
26+
end
27+
28+
typ = ir[ssa][:type]
29+
has_simple_incidence_info(typ) || return false
30+
31+
# we have custom handling for things without any dependency on time nor state
32+
return !has_dependence(typ)
33+
end
34+
35+
function is_diffed_equation_call_invoke_or_call(@nospecialize(stmt), ir::IRCode)
36+
(isexpr(stmt, :invoke) || isexpr(stmt, :call)) || return false
37+
callee = _eq_function_arg(stmt)
38+
isa(callee, SSAValue) || return false
39+
bundlecall = ir[callee][:inst]
40+
isexpr(bundlecall, :call) || return false
41+
bt = bundlecall.args[1]
42+
isa(bt, Type) || return false
43+
bt <: Diffractor.TaylorBundle || return false
44+
ft = argextype(bundlecall.args[2], ir)
45+
return widenconst(ft) === equation
46+
end
47+
48+
function index_lowering_ad!(state::TransformationState, key::TornCacheKey)
49+
(; result, structure) = state
50+
(; var_to_diff, eq_to_diff, graph, solvable_graph) = structure
51+
52+
ir = state.result.ir
53+
54+
# Figure out which equations we need to differentiate
55+
# TODO: Should have some nicer interface in MTK
56+
diff_eqs = Pair{Int, Int}[]
57+
for i = 1:length(eq_to_diff)
58+
# If this is a linear equation, we cannot differentiate it, because
59+
# alias elimination changed the equation on us, but didn't update the
60+
# IR. We codegen it directly below.
61+
if invview(eq_to_diff)[i] === nothing && eq_to_diff[i] !== nothing && !isempty(𝑠neighbors(graph, eq_to_diff[i]))
62+
level = 1
63+
diff = eq_to_diff[i]
64+
is_fully_state_linear(result.total_incidence[i], key.param_vars) && continue
65+
while (diff = eq_to_diff[diff]) !== nothing
66+
level += 1
67+
end
68+
push!(diff_eqs, i => level)
69+
end
70+
end
71+
72+
# Mark all non-trivial `ddt()` statements as ones that we should differentiate
73+
diff_ssas = Pair{SSAValue,Int}[]
74+
for i = 1:length(ir.stmts)
75+
if is_known_invoke(ir.stmts[i][:stmt], ddt, ir) && !is_const_plus_state_linear(argextype(ir.stmts[i][:stmt].args[end], ir), key.param_vars)
76+
push!(diff_ssas, SSAValue(i) => 0)
77+
end
78+
end
79+
80+
if isempty(diff_ssas) && isempty(diff_eqs)
81+
return copy(ir)
82+
end
83+
84+
compact = IncrementalCompact(copy(ir))
85+
# TODO: This could all be combined with the below into a single pass
86+
(eqs, vars) = find_eqs_vars(state.structure.graph, compact)
87+
ir = Compiler.finish(compact)
88+
89+
for (eq, level) in diff_eqs
90+
for ssa in eqs[eq][2]
91+
push!(diff_ssas, ssa => level)
92+
end
93+
end
94+
95+
append!(eqs, (SSAValue(0)=>SSAValue[] for _ in 1:(length(eq_to_diff)-length(eqs))))
96+
append!(vars, fill(SSAValue(0), length(var_to_diff)-length(vars)))
97+
domtree = Compiler.construct_domtree(ir.cfg.blocks)
98+
99+
function diff_one!(ir, ssa, dvar)
100+
if dvar === nothing
101+
# dvar can be `nothing` if we are differentiating a variable that doesn't actually appear
102+
# in the matched system structure's incidence analysis for the equation currently being differentiated.
103+
# This can occur because Diffractor's types bundle both the primal and the tangent derivatives
104+
# in a single type, causing differentiation of all listed variables to hit this function.
105+
# We emit here a `_DIFF_UNUSED` value that we expect to never be used and DCE'd later on in the pipeline.
106+
return insert_node!(ir, ssa, NewInstruction(GlobalRef(DAECompiler.Intrinsics, :_DIFF_UNUSED), Incidence(Float64), Int32(1)))
107+
end
108+
if vars[dvar] == SSAValue(0)
109+
vars[dvar] = insert_node!(ir, ssa, NewInstruction(Expr(:invoke, nothing, variable), Incidence(dvar)))
110+
elseif !dominates_ssa(ir, domtree, vars[dvar], ssa; dominates_after=true)
111+
varssa = vars[dvar]
112+
inst = ir[varssa]
113+
vars[dvar] = insert_node!(ir, ssa, NewInstruction(inst))
114+
ir[varssa][:inst] = vars[dvar]
115+
end
116+
return vars[dvar]
117+
end
118+
119+
function diff_variable!(ir, ssa, stmt, order)
120+
inst = ir[ssa]
121+
var = idnum(ir[ssa][:type])
122+
primal = insert_node!(ir, ssa, NewInstruction(inst))
123+
vars[var] = primal
124+
diffs = SSAValue[]
125+
for i = 1:order
126+
var !== nothing && (var = var_to_diff[var])
127+
push!(diffs, diff_one!(ir, ssa, var))
128+
end
129+
duals = insert_node!(ir, ssa, NewInstruction(
130+
Expr(:call, tuple, diffs...), Any
131+
))
132+
replace_call!(ir, ssa, Expr(:call, Diffractor.TaylorBundle{order}, primal, duals))
133+
end
134+
135+
function transform!(ir, ssa, order, maparg)
136+
if isa(ssa, Argument)
137+
# at start of function define a SSA holding the initially accumulated derivative of each argument, (i.e. 0)
138+
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, Diffractor.zero_bundle{order}(), ssa), Any))
139+
end
140+
inst = ir[ssa]
141+
stmt = inst[:inst]
142+
while isa(stmt, SSAValue)
143+
# It's possible an earlier call to transform! moved this call, so follow references.
144+
stmt = ir[stmt][:inst]
145+
end
146+
if is_known_invoke(stmt, variable, ir)
147+
diff_variable!(ir, ssa, stmt, order)
148+
return nothing
149+
elseif is_known_invoke(stmt, equation, ir)
150+
eq = inst[:type].id
151+
primal = insert_node!(ir, ssa, NewInstruction(inst))
152+
eqs[eq] = primal=>eqs[eq][2]
153+
duals = SSAValue[]
154+
for _ = 1:order
155+
deq = eq_to_diff[eq]
156+
# If `deq` is nothing, that means we're asking for a derivative of an equation
157+
# that does not exist. This is possible if we, for instance, have a tuple of
158+
# equation-related values that does not get SROA'ed, and is then differentiated
159+
# by Diffractor due to _one_ of the equations being differentiated. But that
160+
# results in this loop asking for derivatives of the _other_ equations that
161+
# don't exist. To handle this, we insert a bogus equation node, similar in
162+
# spirit to the `_DIFF_UNUSED` value.
163+
if deq === nothing
164+
diff = insert_node!(ir, ssa, NewInstruction(GlobalRef(DAECompiler.Intrinsics, :_EQ_UNUSED), equation))
165+
else
166+
diff = insert_node!(ir, ssa, NewInstruction(inst))
167+
diffinst = ir[diff]
168+
diffinst[:type] = Eq(deq)
169+
eqs[deq] = diff=>eqs[deq][2]
170+
eq = deq
171+
end
172+
push!(duals, diff)
173+
end
174+
dtup = insert_node!(ir, ssa, NewInstruction(
175+
Expr(:call, tuple, duals...), Any
176+
))
177+
# N.B.: No replace_call!, because we rely on the type of this call.
178+
inst[:inst] = Expr(:call, Diffractor.TaylorBundle{order}, primal, dtup)
179+
inst[:info] = Compiler.NoCallInfo()
180+
return nothing
181+
elseif is_known_invoke(stmt, sim_time, ir)
182+
time = insert_node!(ir, ssa, NewInstruction(inst))
183+
replace_call!(ir, ssa, Expr(:call, Diffractor.∂xⁿ{order}(), time))
184+
return nothing
185+
elseif is_diffed_equation_call_invoke_or_call(stmt, ir)
186+
eq = idnum(argextype(_eq_function_arg(stmt), ir))
187+
bundle = _eq_val_arg(stmt)
188+
# Rewrite the equation (we could extract it from the bundle, but we already know where it is)
189+
# N.B.: We don't need replace_call! here, because we're not changing the call target,
190+
# we're just rearranging the SSA.
191+
inst[:inst] = Expr(
192+
:call,
193+
eqs[eq][1],
194+
insert_node!(ir, ssa, NewInstruction(Expr(:call, getfield, bundle, 1), Any)), # primal
195+
)
196+
# Pull out the equation from the primal, so we can null it out below
197+
new_primal = insert_node!(ir, ssa, NewInstruction(inst))
198+
replace!(eqs[eq][2], ssa=>new_primal)
199+
for i = 1:order
200+
val = insert_node!(ir, ssa, NewInstruction(Expr(:call, getindex, bundle, Diffractor.TaylorTangentIndex(i)), Any))
201+
push!(
202+
eqs[eq_to_diff[eq]][2],
203+
insert_node!(ir, ssa, NewInstruction(Expr(:call, eqs[eq_to_diff[eq]][1], val), Any))
204+
)
205+
eq = eq_to_diff[eq]
206+
end
207+
# equation! also returns nothing, but it's possible for the value
208+
# to be used (e.g. by a return, so conform to the interface)
209+
dnullout_inst!(inst, order)
210+
elseif is_known_invoke(stmt, ddt, ir)
211+
arg = maparg(stmt.args[end], ssa, order+1)
212+
if order == 0
213+
replace_call!(ir, ssa, Expr(:call, Diffractor.partial, arg, 1))
214+
else
215+
replace_call!(ir, ssa, Expr(:call, diff_bundle, arg))
216+
end
217+
return nothing
218+
else
219+
# must be something with no dependency
220+
@assert !has_dependence(inst[:type])
221+
urs = userefs(stmt)
222+
for ur in urs
223+
ur[] = maparg(ur[], ssa, 0)
224+
end
225+
inst[:inst] = urs[]
226+
primal = insert_node!(ir, ssa, NewInstruction(inst))
227+
replace_call!(ir, ssa, Expr(:call, Diffractor.zero_bundle{order}(), primal))
228+
return nothing
229+
end
230+
end
231+
Diffractor.forward_diff_no_inf!(ir, diff_ssas; visit_custom! = tearing_visit_custom!, transform!, eras_mode=true)
232+
233+
# Rename state
234+
compact = IncrementalCompact(ir)
235+
(eqs, vars) = find_eqs_vars(state.structure.graph, compact)
236+
# Some variables may look dead, but are used in linear equations
237+
# don't dce them just yet - we'll dce them below
238+
Compiler.non_dce_finish!(compact)
239+
ir = Compiler.complete(compact)
240+
241+
# Derivatives can appear out of "thin air" due to implicit dependencies
242+
# (i.e. an equation that depends on 1 also depends on ddt(1)), or due to
243+
# imprecision introduced by the AD transform (causing a primal to
244+
# spuriously be carried along in the Incidence with its derivative).
245+
#
246+
# Allow this by verifying there is an element in `g` whose k-derivative
247+
# is `var` (k ∈ ℤ).
248+
function in_any_derivative(var, g)
249+
while var_to_diff[var] !== nothing
250+
var = var_to_diff[var] # Normalize to highest-derivative
251+
end
252+
while true
253+
var in g && return true
254+
invview(var_to_diff)[var] === nothing && return false
255+
var = invview(var_to_diff)[var]
256+
end
257+
end
258+
259+
# Update solvable graph
260+
#=
261+
for (eq, (_, eqssas)) in enumerate(eqs)
262+
is_fully_state_linear(state.total_incidence[eq], key.param_vars) && continue
263+
old_graph = empty_eq_list!(graph, eq)
264+
old_solvable_graph = empty_eq_list!(solvable_graph, eq)
265+
for eqssa in eqssas
266+
if ir[eqssa][:inst] === nothing
267+
# Could have been in a dead branch and deleted - allow that for now.
268+
continue
269+
end
270+
eqssaval = _eq_val_arg(ir[eqssa][:inst])
271+
inc = ir[eqssaval][:type]
272+
if !isa(inc, Incidence)
273+
throw(UnsupportedIRException("Expected incidence analysis to produce result for $eqssaval, got $inc", ir))
274+
end
275+
for (v, coeff) in zip(rowvals(inc.row), nonzeros(inc.row))
276+
v == 1 && continue
277+
@assert in_any_derivative(v-1, old_graph)
278+
@assert !has_edge(graph, BipartiteEdge(eq, v-1))
279+
add_edge!(graph, eq, v-1)
280+
if coeff !== nonlinear
281+
add_edge!(solvable_graph, eq, v-1)
282+
else
283+
# TODO: solvable should generally not become unsolvable but in some cases
284+
# our AD transform widens Incidence propagation in a way that artificially
285+
# makes tearing's life harder (see downstream BSIM-CMG test)
286+
# @assert !(v-1 in old_solvable_graph)
287+
if v-1 in old_solvable_graph
288+
@debug "Variable $(v-1) in Eq. $(eq) went from solvable -> unsolvable after AD transform"
289+
end
290+
end
291+
end
292+
end
293+
end
294+
=#
295+
296+
return ir
297+
end
298+
299+
function empty_eq_list!(graph::BipartiteGraph, eq)
300+
vs = copy(𝑠neighbors(graph, eq))
301+
foreach(vs) do v
302+
rem_edge!(graph, eq, v)
303+
end
304+
return vs
305+
end

src/transform/index_lowering_ad.jl

Lines changed: 0 additions & 28 deletions
This file was deleted.

0 commit comments

Comments
 (0)