Skip to content

Commit 025f225

Browse files
committed
Track external registrations of Operations, so they can be injected into generated functions.This solves the issue where an external package may wish to use a function it defines "my_func" as an Operation, and so calls "@register my_func(x)" on it. However, when ModelingToolkit eventually generates a function including that Operation, GeneralizedGenerated renames all calls to the function to "(ModelingToolkit).my_func", which does not exist because that function was not defined in the ModelingToolkit namespace. Instead, we track all the functions that are registered and which module they come from, and use that to re-write the generated functions with the appropriate module.
1 parent 2c09502 commit 025f225

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1212
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1313
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1414
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
15+
JuliaVariables = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec"
1516
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
1617
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -35,6 +36,7 @@ DiffEqJump = "6.7.5"
3536
DiffRules = "0.1, 1.0"
3637
DocStringExtensions = "0.7, 0.8"
3738
GeneralizedGenerated = "0.1.4, 0.2"
39+
JuliaVariables = "0.2.0"
3840
Latexify = "0.11, 0.12, 0.13"
3941
LightGraphs = "1.3"
4042
MacroTools = "0.5"

src/function_registration.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ ModelingToolkit IR. Example:
1010
```
1111
1212
registers `f` as a possible two-argument function.
13+
14+
NOTE: If registering outside of ModelingToolkit (i.e. in your own package),
15+
this should be done at runtime (e.g. in a package `__init__()`, or inside a
16+
method that is called at runtime and not during precompile) to ensure that
17+
any generated functions will use your registered method. See
18+
`inject_registered_module_functions()`.
1319
"""
1420
macro register(sig)
1521
splitsig = splitdef(:($sig = nothing))
@@ -22,6 +28,11 @@ macro register(sig)
2228
splitsig[:body] = :(Operation($name, Expression[$(args...)]))
2329
defs = :($defs; $(combinedef(splitsig)))
2430
end
31+
if (@__MODULE__) != ModelingToolkit
32+
# Register external module registrations so we can rewrite any generated
33+
# functions through JuliaVariables.solve().
34+
get!(ModelingToolkit.registered_external_functions, name, @__MODULE__)
35+
end
2536
esc(defs)
2637
end
2738
# Create all valid combinations of Expression,Number for function signature
@@ -69,3 +80,43 @@ Base.:^(x::Expression,y::T) where T <: Rational = Operation(Base.:^, Expression[
6980

7081
Base.getindex(x::Operation,i::Int64) = Operation(getindex,[x,i])
7182
Base.one(::Operation) = 1
83+
84+
# ---
85+
# Ensure that Operations that get @registered from outside the ModelingToolkit module can work without
86+
# having to bring in the associated function into the ModelingToolkit namespace.
87+
# We basically store information about functions registered at runtime in a ModelingToolkit variable,
88+
# `registered_external_functions`. It's not pretty, but we are limited by the way GeneralizedGenerated
89+
# builds a function (adding "ModelingToolkit" to every function call).
90+
# ---
91+
import JuliaVariables
92+
const registered_external_functions = Dict{Symbol,Module}()
93+
function inject_registered_module_functions(expr)
94+
MacroTools.postwalk(expr) do x
95+
MacroTools.@capture(x, f_(xs__)) # We need to find all function calls in the expression.
96+
# If the function call has been converted to a JuliaVariables.Var and matches
97+
# one of the functions we've registered...
98+
if !isnothing(f) && x.args[1] isa JuliaVariables.Var && x.args[1].name in keys(registered_external_functions)
99+
# Rewrite it from a Var to a regular function call.
100+
x.args[1] = getproperty(registered_external_functions[x.args[1].name], x.args[1].name)
101+
end
102+
return x # Make sure we rebuild the expression as is.
103+
end
104+
end
105+
106+
# TODO: Overwriting this function works, but is quite ugly. Is there a nicer way to inject the module names?
107+
function GeneralizedGenerated.mk_function(mod::Module, ex)
108+
ex = macroexpand(mod, ex)
109+
ex = GeneralizedGenerated.simplify_ex(ex)
110+
ex = GeneralizedGenerated.solve(ex)
111+
112+
# We need to modify the expression built by the JuliaVariables.solve(ex)
113+
# method, before GeneralizedGenerated.closure_conv(mod, ex) converts it to a
114+
# RuntimeFn (as done in GeneralizedGenerated.mk_function(mod, ex)).
115+
ex = inject_registered_module_functions(ex)
116+
117+
fn = GeneralizedGenerated.closure_conv(mod, ex)
118+
if !(fn isa GeneralizedGenerated.RuntimeFn)
119+
error("Expect an unnamed function expression. ")
120+
end
121+
fn
122+
end

0 commit comments

Comments
 (0)