Skip to content

Commit 807169c

Browse files
refactor: modularize code and enable array hacks
1 parent 55c4593 commit 807169c

File tree

1 file changed

+53
-54
lines changed

1 file changed

+53
-54
lines changed

ext/MTKFMIExt.jl

Lines changed: 53 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -36,68 +36,43 @@ function MTK.FMIComponent(::Val{2}, ::Val{:ME}; fmu = nothing, tolerance = 1e-6,
3636
states = []
3737
diffvars = []
3838
observed = Equation[]
39-
stateT = Float64
40-
for (valRef, snames) in FMI.getStateValueReferencesAndNames(fmu)
41-
stateT = FMI.dataTypeForValueReference(fmu, valRef)
42-
snames = map(parseFMIVariableName, snames)
43-
vars = [MTK.unwrap(only(@variables $sname(t)::stateT)) for sname in snames]
44-
for i in eachindex(vars)
45-
if i == 1
46-
push!(diffvars, vars[i])
47-
else
48-
push!(observed, vars[i] ~ vars[1])
49-
end
50-
value_references[vars[i]] = valRef
51-
end
52-
append!(states, vars)
39+
fmi_variables_to_mtk_variables!(fmu, FMI.getStateValueReferencesAndNames(fmu),
40+
value_references, diffvars, states, observed)
41+
if isempty(diffvars)
42+
__mtk_internal_u = []
43+
else
44+
@variables __mtk_internal_u(t)[1:length(diffvars)]
45+
push!(observed, __mtk_internal_u ~ copy(diffvars))
46+
push!(states, __mtk_internal_u)
5347
end
5448

5549
inputs = []
56-
for (valRef, snames) in FMI.getInputValueReferencesAndNames(fmu)
57-
snames = map(parseFMIVariableName, snames)
58-
vars = [MTK.unwrap(only(@variables $sname(t)::stateT)) for sname in snames]
59-
for i in eachindex(vars)
60-
if i == 1
61-
push!(inputs, vars[i])
62-
else
63-
push!(observed, vars[i] ~ vars[1])
64-
end
65-
value_references[vars[i]] = valRef
66-
end
67-
append!(states, vars)
50+
fmi_variables_to_mtk_variables!(fmu, FMI.getInputValueReferencesAndNames(fmu),
51+
value_references, inputs, states, observed)
52+
if isempty(inputs)
53+
__mtk_internal_x = []
54+
else
55+
@variables __mtk_internal_x(t)[1:length(inputs)]
56+
push!(observed, __mtk_internal_x ~ copy(inputs))
57+
push!(states, __mtk_internal_x)
6858
end
6959

7060
outputs = []
71-
for (valRef, snames) in FMI.getOutputValueReferencesAndNames(fmu)
72-
snames = map(parseFMIVariableName, snames)
73-
vars = [MTK.unwrap(only(@variables $sname(t)::stateT)) for sname in snames]
74-
for i in eachindex(vars)
75-
if i == 1
76-
push!(outputs, vars[i])
77-
else
78-
push!(observed, vars[i] ~ vars[1])
79-
end
80-
value_references[vars[i]] = valRef
81-
end
82-
append!(states, vars)
83-
end
61+
fmi_variables_to_mtk_variables!(fmu, FMI.getOutputValueReferencesAndNames(fmu),
62+
value_references, outputs, states, observed)
63+
# @variables __mtk_internal_o(t)[1:length(outputs)]
64+
# push!(observed, __mtk_internal_o ~ outputs)
8465

8566
params = []
8667
parameter_dependencies = Equation[]
87-
for (valRef, pnames) in FMI.getParameterValueReferencesAndNames(fmu)
88-
defval = FMI.getStartValue(fmu, valRef)
89-
T = FMI.dataTypeForValueReference(fmu, valRef)
90-
pnames = map(parseFMIVariableName, pnames)
91-
vars = [MTK.unwrap(only(@parameters $pname::T)) for pname in pnames]
92-
for i in eachindex(vars)
93-
if i == 1
94-
push!(params, vars[i])
95-
else
96-
push!(parameter_dependencies, vars[i] ~ vars[1])
97-
end
98-
value_references[vars[i]] = valRef
99-
end
100-
defs[vars[1]] = defval
68+
fmi_variables_to_mtk_variables!(
69+
fmu, FMI.getParameterValueReferencesAndNames(fmu), value_references,
70+
params, [], parameter_dependencies, defs; parameters = true)
71+
if isempty(params)
72+
__mtk_internal_p = []
73+
else
74+
@parameters __mtk_internal_p[1:length(params)]
75+
push!(parameter_dependencies, __mtk_internal_p ~ copy(params))
10176
end
10277

10378
input_value_references = UInt32[value_references[var] for var in inputs]
@@ -109,7 +84,7 @@ function MTK.FMIComponent(::Val{2}, ::Val{:ME}; fmu = nothing, tolerance = 1e-6,
10984
buffer_length = length(diffvars) + length(outputs)
11085
_functor = FMI2MEFunctor(zeros(buffer_length), output_value_references)
11186
@parameters (functor::(typeof(_functor)))(..)[1:buffer_length] = _functor
112-
call_expr = functor(wrapper, copy(diffvars), copy(inputs), copy(params), t)
87+
call_expr = functor(wrapper, __mtk_internal_u, __mtk_internal_x, __mtk_internal_p, t)
11388

11489
diffeqs = Equation[]
11590
for (i, var) in enumerate([D.(diffvars); outputs])
@@ -127,6 +102,30 @@ function MTK.FMIComponent(::Val{2}, ::Val{:ME}; fmu = nothing, tolerance = 1e-6,
127102
discrete_events = [instance_management_callback], name)
128103
end
129104

105+
function fmi_variables_to_mtk_variables!(fmu, varmap, value_references, truevars, allvars,
106+
obseqs, defs = Dict(); parameters = false)
107+
for (valRef, snames) in varmap
108+
stateT = FMI.dataTypeForValueReference(fmu, valRef)
109+
snames = map(parseFMIVariableName, snames)
110+
if parameters
111+
vars = [MTK.unwrap(only(@parameters $sname::stateT)) for sname in snames]
112+
else
113+
vars = [MTK.unwrap(only(@variables $sname(t)::stateT)) for sname in snames]
114+
end
115+
for i in eachindex(vars)
116+
if i == 1
117+
push!(truevars, vars[i])
118+
else
119+
push!(obseqs, vars[i] ~ vars[1])
120+
end
121+
value_references[vars[i]] = valRef
122+
end
123+
append!(allvars, vars)
124+
defval = FMI.getStartValue(fmu, valRef)
125+
defs[vars[1]] = defval
126+
end
127+
end
128+
130129
function parseFMIVariableName(name::AbstractString)
131130
return Symbol(replace(name, "." => "__"))
132131
end

0 commit comments

Comments
 (0)