Skip to content

Commit fc98fc4

Browse files
Merge pull request #445 from dpad/track-external-registrations
Track external registrations of Operations, so they can be injected into generated functions.
2 parents 2c09502 + 33b9173 commit fc98fc4

File tree

4 files changed

+124
-3
lines changed

4 files changed

+124
-3
lines changed

src/build_function.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,26 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
127127
if expression == Val{true}
128128
return oop_ex
129129
else
130-
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex)
130+
_build_and_inject_function(@__MODULE__, oop_ex)
131131
end
132132
end
133133

134+
function _build_and_inject_function(mod::Module, ex)
135+
# Generate the function, which will process the expression
136+
runtimefn = GeneralizedGenerated.mk_function(@__MODULE__, ex)
137+
138+
# Extract the processed expression of the function body
139+
params = typeof(runtimefn).parameters
140+
fn_expr = GeneralizedGenerated.NGG.from_type(params[3])
141+
142+
# Inject our externally registered module functions
143+
new_expr = ModelingToolkit.inject_registered_module_functions(fn_expr)
144+
145+
# Reconstruct the RuntimeFn's Body
146+
new_body = GeneralizedGenerated.NGG.to_type(new_expr)
147+
return GeneralizedGenerated.RuntimeFn{params[1:2]..., new_body, params[4]}()
148+
end
149+
134150
# Detect heterogeneous element types of "arrays of matrices/sparce matrices"
135151
function is_array_matrix(F)
136152
return isa(F, AbstractVector) && all(x->isa(x, AbstractArray), F)
@@ -288,7 +304,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
288304
if expression == Val{true}
289305
return oop_ex, iip_ex
290306
else
291-
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex), GeneralizedGenerated.mk_function(@__MODULE__,iip_ex)
307+
return _build_and_inject_function(@__MODULE__, oop_ex), _build_and_inject_function(@__MODULE__, iip_ex)
292308
end
293309
end
294310

src/function_registration.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,30 @@ registers `f` as a possible two-argument function.
1414
macro register(sig)
1515
splitsig = splitdef(:($sig = nothing))
1616
name = splitsig[:name]
17+
18+
# Extract the module and function name from the signature
19+
if name isa Symbol
20+
mod = __module__ # Calling module
21+
funcname = name
22+
else
23+
mod = name.args[1]
24+
funcname = name.args[2].value
25+
end
26+
1727
args = splitsig[:args]
1828
typargs = typed_args(args)
1929
defs = :()
2030
for typarg in typargs
2131
splitsig[:args] = typarg
22-
splitsig[:body] = :(Operation($name, Expression[$(args...)]))
32+
if mod == (@__MODULE__) # If the calling module is ModelingToolkit itself...
33+
splitsig[:body] = :(Operation($name, Expression[$(args...)]))
34+
else
35+
# Register the function's associated model so we can inject it in later.
36+
splitsig[:body] = quote
37+
get!(ModelingToolkit.registered_external_functions, Symbol($("$funcname")), $mod)
38+
Operation($name, Expression[$(args...)])
39+
end
40+
end
2341
defs = :($defs; $(combinedef(splitsig)))
2442
end
2543
esc(defs)
@@ -69,3 +87,30 @@ Base.:^(x::Expression,y::T) where T <: Rational = Operation(Base.:^, Expression[
6987

7088
Base.getindex(x::Operation,i::Int64) = Operation(getindex,[x,i])
7189
Base.one(::Operation) = 1
90+
91+
# Ensure that Operations that get @registered from outside the ModelingToolkit
92+
# module can work without having to bring in the associated function into the
93+
# ModelingToolkit namespace. We basically store information about functions
94+
# registered at runtime in a ModelingToolkit variable,
95+
# `registered_external_functions`. It's not pretty, but we are limited by the
96+
# way GeneralizedGenerated builds a function (adding "ModelingToolkit" to every
97+
# function call).
98+
# ---
99+
const registered_external_functions = Dict{Symbol,Module}()
100+
function inject_registered_module_functions(expr)
101+
MacroTools.postwalk(expr) do x
102+
# We need to find all function calls in the expression...
103+
MacroTools.@capture(x, f_(xs__))
104+
105+
if !isnothing(f) && f isa Expr && f.head == :. && f.args[2] isa QuoteNode
106+
# If the function call matches any of the functions we've
107+
# registered, set the calling module (which is probably
108+
# "ModelingToolkit") to the module it is registered to.
109+
f_name = f.args[2].value # function name
110+
f.args[1] = get(registered_external_functions, f_name, f.args[1])
111+
end
112+
113+
# Make sure we rebuild the expression as is.
114+
return x
115+
end
116+
end

test/function_registration.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# TEST: Function registration in a module.
2+
# ------------------------------------------------
3+
module MyModule
4+
using ModelingToolkit, DiffEqBase, LinearAlgebra, Test
5+
@parameters t x
6+
@variables u(t)
7+
@derivatives Dt'~t
8+
9+
function do_something(a)
10+
a + 10
11+
end
12+
@register do_something(a)
13+
14+
eq = Dt(u) ~ do_something(x)
15+
sys = ODESystem([eq], t, [u], [x])
16+
fun = ODEFunction(sys)
17+
18+
@test fun([0.5], [5.0], 0.) == [15.0]
19+
end
20+
21+
# TEST: Function registration in a nested module.
22+
# ------------------------------------------------
23+
module MyModule2
24+
module MyNestedModule
25+
using ModelingToolkit, DiffEqBase, LinearAlgebra, Test
26+
@parameters t x
27+
@variables u(t)
28+
@derivatives Dt'~t
29+
30+
function do_something_2(a)
31+
a + 20
32+
end
33+
@register do_something_2(a)
34+
35+
eq = Dt(u) ~ do_something_2(x)
36+
sys = ODESystem([eq], t, [u], [x])
37+
fun = ODEFunction(sys)
38+
39+
@test fun([0.5], [3.0], 0.) == [23.0]
40+
end
41+
end
42+
43+
# TEST: Function registration outside any modules.
44+
# ------------------------------------------------
45+
using ModelingToolkit, DiffEqBase, LinearAlgebra, Test
46+
@parameters t x
47+
@variables u(t)
48+
@derivatives Dt'~t
49+
50+
function do_something_3(a)
51+
a + 30
52+
end
53+
@register do_something_3(a)
54+
55+
eq = Dt(u) ~ do_something_3(x)
56+
sys = ODESystem([eq], t, [u], [x])
57+
fun = ODEFunction(sys)
58+
59+
@test fun([0.5], [7.0], 0.) == [37.0]

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ using SafeTestsets, Test
2020
@safetestset "Lowering Integration Test" begin include("lowering_solving.jl") end
2121
@safetestset "Test Big System Usage" begin include("bigsystem.jl") end
2222
@safetestset "Depdendency Graph Test" begin include("dep_graphs.jl") end
23+
@safetestset "Function Registration Test" begin include("function_registration.jl") end
2324
#@testset "Latexify recipes Test" begin include("latexify.jl") end
2425
@testset "Distributed Test" begin include("distributed.jl") end
2526
@testset "Array of Array Test" begin include("build_function_arrayofarray.jl") end

0 commit comments

Comments
 (0)