Skip to content

Commit 979df0d

Browse files
working cfunction compilation
1 parent adf373a commit 979df0d

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using SpecialFunctions, NaNMath
1111
using Base.Threads
1212
import MacroTools: splitdef, combinedef, postwalk, striplines
1313
import GeneralizedGenerated
14+
import Libdl
1415
using DocStringExtensions
1516
using Base: RefValue
1617

@@ -156,6 +157,7 @@ export build_function
156157
export @register
157158
export modelingtoolkitize
158159
export @variables, @parameters
160+
export compile_cfunction
159161

160162
const HAS_DAGGER = Ref{Bool}(false)
161163
function __init__()

src/build_function.jl

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -430,38 +430,38 @@ function numbered_expr(O::Equation,args...;kwargs...)
430430
:($(numbered_expr(O.lhs,args...;kwargs...)) = $(numbered_expr(O.rhs,args...;kwargs...)))
431431
end
432432

433-
function numbered_expr(O::Operation,vars,parameters;
433+
function numbered_expr(O::Operation,vars,parameters;offset = 0,
434434
derivname=:du,
435435
varname=:u,paramname=:p)
436436
if isa(O.op, ModelingToolkit.Differential)
437437
varop = O.args[1]
438438
i = get_varnumber(varop,vars)
439-
return :($derivname[$i])
439+
return :($derivname[$(i+offset)])
440440
elseif isa(O.op, ModelingToolkit.Variable)
441441
i = get_varnumber(O,vars)
442442
if i == nothing
443443
i = get_varnumber(O,parameters)
444-
return :($paramname[$i])
444+
return :($paramname[$(i+offset)])
445445
else
446-
return :($varname[$i])
446+
return :($varname[$(i+offset)])
447447
end
448448
end
449449
return Expr(:call, Symbol(O.op),
450-
[numbered_expr(x,vars,parameters;derivname=derivname,
450+
[numbered_expr(x,vars,parameters;offset=offset,derivname=derivname,
451451
varname=varname,paramname=paramname) for x in O.args]...)
452452
end
453453

454454
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{<:Variable},parameters;
455-
derivname=:du,varname=:u,paramname=:p)
455+
derivname=:du,varname=:u,paramname=:p,offset=0)
456456
i = findfirst(x->isequal(x.name,var_from_nested_derivative(de.lhs)[1].name),vars)
457-
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;
457+
:($derivname[$(i+offset)] = $(numbered_expr(de.rhs,vars,parameters;offset=offset,
458458
derivname=derivname,
459459
varname=varname,paramname=paramname)))
460460
end
461461
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{Operation},parameters;
462-
derivname=:du,varname=:u,paramname=:p)
462+
derivname=:du,varname=:u,paramname=:p,offset=0)
463463
i = findfirst(x->isequal(x.op.name,var_from_nested_derivative(de.lhs)[1].name),vars)
464-
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;
464+
:($derivname[$(i+offset)] = $(numbered_expr(de.rhs,vars,parameters;offset=offset,
465465
derivname=derivname,
466466
varname=varname,paramname=paramname)))
467467
end
@@ -488,10 +488,10 @@ function _build_function(target::CTarget, eqs, vs, ps, iv;
488488
fname = :diffeqf, derivname=:internal_var___du,
489489
varname=:internal_var___u,paramname=:internal_var___p)
490490
differential_equation = string(join([numbered_expr(eq,vs,ps,derivname=derivname,
491-
varname=varname,paramname=paramname) for
491+
varname=varname,paramname=paramname,offset=-1) for
492492
(i, eq) enumerate(eqs)],";\n "),";")
493493
"""
494-
void $fname(double* $derivname, double* $varname, double* $paramname, $iv) {
494+
void $fname(double* $derivname, double* $varname, double* $paramname, double $iv) {
495495
$differential_equation
496496
}
497497
"""
@@ -510,3 +510,22 @@ function _build_function(target::MATLABTarget, eqs, vs, ps, iv;
510510
matstr = "$fname = @(t,$varname) ["*matstr*"];"
511511
matstr
512512
end
513+
514+
"""
515+
compile_cfunction(eqs,args...;libpath=tempname(),compiler=:gcc)
516+
517+
Builds a function in C, compiles it, and returns a lambda to that compiled function.
518+
Arguments match those of `build_function`. Keyword arguments:
519+
520+
- libpath: the path to store the binary. Defaults to a temporary path.
521+
- compiler: which C compiler to use. Defaults to :gcc, which is currently the
522+
only available option.
523+
"""
524+
function compile_cfunction(eqs,args...;libpath=tempname(),compiler=:gcc)
525+
@assert compiler == :gcc
526+
ex = build_function(eqs,args...;target=ModelingToolkit.CTarget())
527+
open(`gcc -fPIC -O3 -msse3 -xc -shared -o $(libpath * "." * Libdl.dlext) -`, "w") do f
528+
print(f, ex)
529+
end
530+
eval(:((du::Array{Float64},u::Array{Float64},p::Array{Float64},t::Float64) -> ccall(("diffeqf", $libpath), Cvoid, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64), du, u, p, t)))
531+
end

0 commit comments

Comments
 (0)