Skip to content

Commit 1e697f7

Browse files
committed
add pretty_f function for inspecting codegen of mtk models
1 parent 1beec11 commit 1e697f7

File tree

3 files changed

+121
-1
lines changed

3 files changed

+121
-1
lines changed

docs/src/API.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ ff_to_constraint
253253
Base.copy(::NetworkDynamics.ComponentModel)
254254
extract_nw
255255
implicit_output
256+
NetworkDynamics.pretty_f
256257
```
257258

258259
## NetworkDynamicsInspector API

src/NetworkDynamics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using Random: Random
2323
using Static: Static, StaticInt
2424
using SciMLBase: VectorContinuousCallback, CallbackSet, DiscreteCallback
2525
using DiffEqCallbacks: DiffEqCallbacks
26+
using MacroTools: postwalk, @capture
2627

2728
@static if VERSION v"1.11.0-0"
2829
using Base: AnnotatedIOBuffer, AnnotatedString
@@ -79,7 +80,6 @@ export DiscreteComponentCallback, PresetTimeComponentCallback
7980
export SymbolicView
8081
include("callbacks.jl")
8182

82-
using MacroTools: postwalk
8383
export @initconstraint, InitConstraint
8484
export @initformula, InitFormula
8585
include("init_constraints.jl")

src/utils.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,122 @@ $implicit_output_docstring
249249
For more information see the NetworkDynamics docs on [fully implicit outputs](@ref Fully-Implicit-Outputs).
250250
"""
251251
implicit_output(x) = 0
252+
253+
254+
"""
255+
NetworkDynamics.pretty_f(v::VertexModel)
256+
257+
For debugging vertex models based off MTK, this function pretty prints the
258+
underlying generated function `f(du, u, in, p, t)` in a more readable way.
259+
"""
260+
function pretty_f(v)
261+
contains(repr(typeof(v.f)), "RuntimeGeneratedFunction") || throw(ArgumentError("pretty_mtk_function only works for MTK based functions"))
262+
rgf = v.f
263+
264+
renamings = Dict(
265+
:ˍ₋out => :du,
266+
:ˍ₋arg1 => :u,
267+
:ˍ₋arg2 => :in,
268+
:ˍ₋arg3 => :p,
269+
:ˍ₋arg4 => :t,
270+
)
271+
272+
body = Base.remove_linenums!(rgf.body)
273+
# some line number nodes are still there, lets remove them
274+
body = postwalk(x -> x isa LineNumberNode ? nothing : x, body)
275+
276+
allinbound = body.head == :macrocall && body.args[1] == Symbol("@inbounds")
277+
278+
inboundpass = if allinbound
279+
postwalk(body) do x
280+
if @capture(x, @inbounds xs_)
281+
xs
282+
else
283+
x
284+
end
285+
end
286+
else
287+
body
288+
end
289+
290+
# expand infix symbols
291+
infix = [+, *, ^, /, <, >, , ]
292+
expand_infix = postwalk(inboundpass) do x
293+
if x infix
294+
Symbol(x)
295+
else
296+
x
297+
end
298+
end
299+
300+
regex = r"^(.*)\(t\)$"
301+
diffregex = r"^Differential\(t\)\((.*)\(t\)\)$"
302+
renamed = postwalk(expand_infix) do x
303+
if x isa Symbol && haskey(renamings, x)
304+
renamings[x]
305+
elseif x isa Symbol && contains(string(x), regex)
306+
m = match(regex, string(x))
307+
Symbol(m.captures[1])
308+
elseif x isa Symbol && contains(string(x), diffregex)
309+
m = match(diffregex, string(x))
310+
Symbol("∂ₜ" * m.captures[1])
311+
else
312+
x
313+
end
314+
end
315+
316+
improve_juxtaposition = postwalk(renamed) do x
317+
if @capture(x, -1a_)
318+
:(- $a)
319+
else
320+
x
321+
end
322+
end
323+
324+
# now we want to expand our p names
325+
# p1, p2, ... = p
326+
uassigment = Expr(:(=), Expr(:tuple, sym(v)...), :u)
327+
inassigment = Expr(:(=), Expr(:tuple, insym(v)...), :in)
328+
passigment = Expr(:(=), Expr(:tuple, psym(v)...), :p)
329+
pushfirst!(improve_juxtaposition.args, uassigment)
330+
pushfirst!(improve_juxtaposition.args, inassigment)
331+
pushfirst!(improve_juxtaposition.args, passigment)
332+
resolved_uip = postwalk(improve_juxtaposition) do x
333+
if @capture(x, p[i_])
334+
psym(v)[i]
335+
elseif @capture(x, in[i_])
336+
insym(v)[i]
337+
elseif @capture(x, u[i_])
338+
sym(v)[i]
339+
else
340+
x
341+
end
342+
end
343+
344+
functionex = quote
345+
function $(Symbol("vmodel_"*string(v.name)))(du, u, in, p, t)
346+
$(resolved_uip)
347+
end
348+
end
349+
Base.remove_linenums!(functionex)
350+
351+
#no lets get rid of empty blocsk
352+
noblocks = postwalk(functionex) do x
353+
if x isa Expr && x.head == :block && length(x.args) == 1
354+
x.args[1]
355+
else
356+
x
357+
end
358+
end
359+
360+
pretty = sprint(show, MIME"text/plain"(), noblocks)
361+
noquote = replace(pretty, r"^:\(" => "")
362+
noend = replace(noquote, r"\)$" => "")
363+
noindent = replace(noend, r"^ "m => "")
364+
365+
if allinbound
366+
noindent = "@inbounds " * noindent
367+
end
368+
369+
print(noindent)
370+
end

0 commit comments

Comments
 (0)