Skip to content

Commit 4c394c5

Browse files
feat: support building v2 Co-Simulation FMU components
1 parent 00cc075 commit 4c394c5

File tree

1 file changed

+191
-47
lines changed

1 file changed

+191
-47
lines changed

ext/MTKFMIExt.jl

Lines changed: 191 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module MTKFMIExt
22

33
using ModelingToolkit
4+
using SymbolicIndexingInterface
45
using ModelingToolkit: t_nounits as t, D_nounits as D
56
import ModelingToolkit as MTK
67
import FMI
@@ -17,20 +18,24 @@ macro statuscheck(expr)
1718
return quote
1819
status = $expr
1920
fnname = $fnname
20-
if (status isa Tuple && status[1] == FMI.fmi2True) ||
21-
(!(status isa Tuple) && status != FMI.fmi2StatusOK &&
22-
status != FMI.fmi2StatusWarning)
21+
if status !== nothing && ((status isa Tuple && status[1] == FMI.fmi2True) ||
22+
(!(status isa Tuple) && status != FMI.fmi2StatusOK &&
23+
status != FMI.fmi2StatusWarning))
2324
if status != FMI.fmi2StatusFatal
2425
FMI.fmi2Terminate(wrapper.instance)
2526
end
2627
FMI.fmi2FreeInstance!(wrapper.instance)
2728
wrapper.instance = nothing
28-
error("FMU Error: status $status")
29+
error("FMU Error in $fnname: status $status")
2930
end
3031
end |> esc
3132
end
3233

33-
function MTK.FMIComponent(::Val{2}, ::Val{:ME}; fmu = nothing, tolerance = 1e-6, name)
34+
function MTK.FMIComponent(::Val{2}; fmu = nothing, tolerance = 1e-6,
35+
communication_step_size = nothing, type, name)
36+
if type == :CS && communication_step_size === nothing
37+
throw(ArgumentError("`communication_step_size` must be specified for Co-Simulation FMUs."))
38+
end
3439
value_references = Dict()
3540
defs = Dict()
3641
states = []
@@ -40,10 +45,12 @@ function MTK.FMIComponent(::Val{2}, ::Val{:ME}; fmu = nothing, tolerance = 1e-6,
4045
value_references, diffvars, states, observed)
4146
if isempty(diffvars)
4247
__mtk_internal_u = []
43-
else
44-
@variables __mtk_internal_u(t)[1:length(diffvars)]
48+
elseif type == :ME
49+
@variables __mtk_internal_u(t)[1:length(diffvars)] [guess = diffvars]
50+
push!(observed, __mtk_internal_u ~ copy(diffvars))
51+
elseif type == :CS
52+
@parameters __mtk_internal_u(t)[1:length(diffvars)]=missing [guess = diffvars]
4553
push!(observed, __mtk_internal_u ~ copy(diffvars))
46-
push!(states, __mtk_internal_u)
4754
end
4855

4956
inputs = []
@@ -52,16 +59,22 @@ function MTK.FMIComponent(::Val{2}, ::Val{:ME}; fmu = nothing, tolerance = 1e-6,
5259
if isempty(inputs)
5360
__mtk_internal_x = []
5461
else
55-
@variables __mtk_internal_x(t)[1:length(inputs)]
62+
@variables __mtk_internal_x(t)[1:length(inputs)] [guess = inputs]
5663
push!(observed, __mtk_internal_x ~ copy(inputs))
5764
push!(states, __mtk_internal_x)
5865
end
5966

6067
outputs = []
6168
fmi_variables_to_mtk_variables!(fmu, FMI.getOutputValueReferencesAndNames(fmu),
6269
value_references, outputs, states, observed)
63-
# @variables __mtk_internal_o(t)[1:length(outputs)]
64-
# push!(observed, __mtk_internal_o ~ outputs)
70+
if type == :CS
71+
if isempty(outputs)
72+
__mtk_internal_o = []
73+
else
74+
@parameters __mtk_internal_o(t)[1:length(outputs)]=missing [guess = zeros(length(outputs))]
75+
push!(observed, __mtk_internal_o ~ outputs)
76+
end
77+
end
6578

6679
params = []
6780
parameter_dependencies = Equation[]
@@ -82,24 +95,69 @@ function MTK.FMIComponent(::Val{2}, ::Val{:ME}; fmu = nothing, tolerance = 1e-6,
8295

8396
output_value_references = UInt32[value_references[var] for var in outputs]
8497
buffer_length = length(diffvars) + length(outputs)
85-
_functor = FMI2MEFunctor(zeros(buffer_length), output_value_references)
86-
@parameters (functor::(typeof(_functor)))(..)[1:buffer_length] = _functor
87-
call_expr = functor(wrapper, __mtk_internal_u, __mtk_internal_x, __mtk_internal_p, t)
8898

89-
diffeqs = Equation[]
90-
for (i, var) in enumerate([D.(diffvars); outputs])
91-
push!(diffeqs, var ~ call_expr[i])
92-
end
99+
initialization_eqs = Equation[]
100+
101+
if type == :ME
102+
_functor = FMI2MEFunctor(zeros(buffer_length), output_value_references)
103+
@parameters (functor::(typeof(_functor)))(..)[1:buffer_length] = _functor
104+
call_expr = functor(
105+
wrapper, __mtk_internal_u, __mtk_internal_x, __mtk_internal_p, t)
106+
107+
diffeqs = Equation[]
108+
for (i, var) in enumerate([D.(diffvars); outputs])
109+
push!(diffeqs, var ~ call_expr[i])
110+
end
111+
112+
finalize_affect = MTK.FunctionalAffect(fmi2Finalize!, [], [wrapper], [])
113+
step_affect = MTK.FunctionalAffect(fmi2MEStep!, [], [wrapper], [])
114+
instance_management_callback = MTK.SymbolicDiscreteCallback(
115+
(t != t - 1), step_affect; finalize = finalize_affect)
116+
117+
push!(params, wrapper, functor)
118+
push!(states, __mtk_internal_u)
119+
elseif type == :CS
120+
state_value_references = UInt32[value_references[var] for var in diffvars]
121+
state_and_output_value_references = vcat(
122+
state_value_references, output_value_references)
123+
_functor = FMI2CSFunctor(state_and_output_value_references,
124+
state_value_references, output_value_references)
125+
@parameters (functor::(typeof(_functor)))(..)[1:(length(__mtk_internal_u) + length(__mtk_internal_o))] = _functor
126+
for (i, x) in enumerate(collect(__mtk_internal_o))
127+
push!(initialization_eqs,
128+
x ~ functor(
129+
wrapper, __mtk_internal_u, __mtk_internal_x, __mtk_internal_p, t)[i])
130+
end
93131

94-
finalize_affect = MTK.FunctionalAffect(fmi2MEFinalize!, [], [wrapper], [])
95-
step_affect = MTK.FunctionalAffect(fmi2MEStep!, [], [wrapper], [])
96-
instance_management_callback = MTK.SymbolicDiscreteCallback(
97-
(t != t - 1), step_affect; finalize = finalize_affect)
132+
diffeqs = Equation[]
133+
134+
cb_observed = (; inputs = __mtk_internal_x, params = copy(params),
135+
t, wrapper, dt = communication_step_size)
136+
cb_modified = (;)
137+
if symbolic_type(__mtk_internal_o) != NotSymbolic()
138+
cb_modified = (cb_modified..., outputs = __mtk_internal_o)
139+
end
140+
if symbolic_type(__mtk_internal_u) != NotSymbolic()
141+
cb_modified = (cb_modified..., states = __mtk_internal_u)
142+
end
143+
initialize_affect = MTK.ImperativeAffect(fmi2CSInitialize!; observed = cb_observed,
144+
modified = cb_modified, ctx = _functor)
145+
finalize_affect = MTK.FunctionalAffect(fmi2Finalize!, [], [wrapper], [])
146+
step_affect = MTK.ImperativeAffect(
147+
fmi2CSStep!; observed = cb_observed, modified = cb_modified, ctx = _functor)
148+
instance_management_callback = MTK.SymbolicDiscreteCallback(
149+
communication_step_size, step_affect; initialize = initialize_affect, finalize = finalize_affect
150+
)
151+
152+
symbolic_type(__mtk_internal_o) == NotSymbolic() || push!(params, __mtk_internal_o)
153+
symbolic_type(__mtk_internal_u) == NotSymbolic() || push!(params, __mtk_internal_u)
154+
155+
push!(params, wrapper, functor)
156+
end
98157

99-
push!(params, wrapper, functor)
100158
eqs = [observed; diffeqs]
101159
return ODESystem(eqs, t, states, params; parameter_dependencies, defaults = defs,
102-
discrete_events = [instance_management_callback], name)
160+
discrete_events = [instance_management_callback], name, initialization_eqs)
103161
end
104162

105163
function fmi_variables_to_mtk_variables!(fmu, varmap, value_references, truevars, allvars,
@@ -142,35 +200,42 @@ function FMI2InstanceWrapper(fmu, params, inputs, tolerance)
142200
FMI2InstanceWrapper(fmu, params, inputs, tolerance, nothing)
143201
end
144202

145-
function get_instance!(wrapper::FMI2InstanceWrapper, states, inputs, params, t)
203+
function get_instance_common!(wrapper::FMI2InstanceWrapper, states, inputs, params, t)
204+
wrapper.instance = FMI.fmi2Instantiate!(wrapper.fmu)::FMI.FMU2Component
205+
if !isempty(params)
206+
@statuscheck FMI.fmi2SetReal(wrapper.instance, wrapper.param_value_references,
207+
Csize_t(length(wrapper.param_value_references)), params)
208+
end
209+
@statuscheck FMI.fmi2SetupExperiment(
210+
wrapper.instance, FMI.fmi2True, wrapper.tolerance, t, FMI.fmi2False, t)
211+
@statuscheck FMI.fmi2EnterInitializationMode(wrapper.instance)
212+
if !isempty(inputs)
213+
@statuscheck FMI.fmi2SetReal(wrapper.instance, wrapper.input_value_references,
214+
Csize_t(length(wrapper.param_value_references)), inputs)
215+
end
216+
217+
return wrapper.instance
218+
end
219+
220+
function get_instance_ME!(wrapper::FMI2InstanceWrapper, states, inputs, params, t)
146221
if wrapper.instance === nothing
147-
wrapper.instance = FMI.fmi2Instantiate!(wrapper.fmu)::FMI.FMU2Component
148-
if !isempty(params)
149-
@statuscheck FMI.fmi2SetReal(wrapper.instance, wrapper.param_value_references,
150-
Csize_t(length(wrapper.param_value_references)), params)
151-
end
152-
@statuscheck FMI.fmi2SetupExperiment(
153-
wrapper.instance, FMI.fmi2True, wrapper.tolerance, t, FMI.fmi2False, t)
154-
@statuscheck FMI.fmi2EnterInitializationMode(wrapper.instance)
155-
if !isempty(inputs)
156-
@statuscheck FMI.fmi2SetReal(wrapper.instance, wrapper.input_value_references,
157-
Csize_t(length(wrapper.param_value_references)), inputs)
158-
end
222+
get_instance_common!(wrapper, states, inputs, params, t)
159223
@statuscheck FMI.fmi2ExitInitializationMode(wrapper.instance)
160224
eventInfo = FMI.fmi2NewDiscreteStates(wrapper.instance)
161225
@assert eventInfo.newDiscreteStatesNeeded == FMI.fmi2False
162226
# TODO: Support FMU events
163227
@statuscheck FMI.fmi2EnterContinuousTimeMode(wrapper.instance)
164228
end
165-
instance = wrapper.instance
166-
@statuscheck FMI.fmi2SetTime(instance, t)
167-
@statuscheck FMI.fmi2SetContinuousStates(instance, states)
168-
if !isempty(inputs)
169-
@statuscheck FMI.fmi2SetReal(instance, wrapper.input_value_references,
170-
Csize_t(length(wrapper.param_value_references)), inputs)
171-
end
172229

173-
return instance
230+
return wrapper.instance
231+
end
232+
233+
function get_instance_CS!(wrapper::FMI2InstanceWrapper, states, inputs, params, t)
234+
if wrapper.instance === nothing
235+
get_instance_common!(wrapper, states, inputs, params, t)
236+
@statuscheck FMI.fmi2ExitInitializationMode(wrapper.instance)
237+
end
238+
return wrapper.instance
174239
end
175240

176241
function complete_step!(wrapper::FMI2InstanceWrapper)
@@ -198,8 +263,19 @@ end
198263
ndims = 1
199264
end
200265

266+
function update_instance_ME!(wrapper::FMI2InstanceWrapper, states, inputs, t)
267+
instance = wrapper.instance
268+
@statuscheck FMI.fmi2SetTime(instance, t)
269+
@statuscheck FMI.fmi2SetContinuousStates(instance, states)
270+
if !isempty(inputs)
271+
@statuscheck FMI.fmi2SetReal(instance, wrapper.input_value_references,
272+
Csize_t(length(wrapper.param_value_references)), inputs)
273+
end
274+
end
275+
201276
function (fn::FMI2MEFunctor)(wrapper::FMI2InstanceWrapper, states, inputs, params, t)
202-
instance = get_instance!(wrapper, states, inputs, params, t)
277+
instance = get_instance_ME!(wrapper, states, inputs, params, t)
278+
update_instance_ME!(wrapper, states, inputs, t)
203279

204280
states_buffer = zeros(length(states))
205281
@statuscheck FMI.fmi2GetDerivatives!(instance, states_buffer)
@@ -214,10 +290,78 @@ function fmi2MEStep!(integrator, u, p, ctx)
214290
complete_step!(wrapper)
215291
end
216292

217-
function fmi2MEFinalize!(integrator, u, p, ctx)
293+
function fmi2Finalize!(integrator, u, p, ctx)
218294
wrapper_idx = p[1]
219295
wrapper = integrator.ps[wrapper_idx]
220296
reset_instance!(wrapper)
221297
end
222298

299+
struct FMI2CSFunctor
300+
state_and_output_value_references::Vector{UInt32}
301+
state_value_references::Vector{UInt32}
302+
output_value_references::Vector{UInt32}
303+
end
304+
305+
function (fn::FMI2CSFunctor)(wrapper::FMI2InstanceWrapper, states, inputs, params, t)
306+
states = states isa SubArray ? copy(states) : states
307+
inputs = inputs isa SubArray ? copy(inputs) : inputs
308+
params = params isa SubArray ? copy(params) : params
309+
instance = get_instance_CS!(wrapper, states, inputs, params, t)
310+
if isempty(fn.output_value_references)
311+
return eltype(states)[]
312+
else
313+
return FMI.fmi2GetReal(instance, fn.output_value_references)
314+
end
315+
end
316+
317+
@register_array_symbolic (fn::FMI2CSFunctor)(
318+
wrapper::FMI2InstanceWrapper, states::Vector{<:Real},
319+
inputs::Vector{<:Real}, params::Vector{<:Real}, t::Real) begin
320+
size = (length(states) + length(fn.output_value_references),)
321+
eltype = eltype(states)
322+
ndims = 1
323+
end
324+
325+
function fmi2CSInitialize!(m, o, ctx::FMI2CSFunctor, integrator)
326+
states = isdefined(m, :states) ? m.states : ()
327+
inputs = o.inputs
328+
params = o.params
329+
t = o.t
330+
wrapper = o.wrapper
331+
if wrapper.instance !== nothing
332+
reset_instance!(wrapper)
333+
end
334+
instance = get_instance_common!(wrapper, states, inputs, params, t)
335+
@statuscheck FMI.fmi2ExitInitializationMode(instance)
336+
if isdefined(m, :states)
337+
@statuscheck FMI.fmi2GetReal!(instance, ctx.state_value_references, m.states)
338+
end
339+
if isdefined(m, :outputs)
340+
@statuscheck FMI.fmi2GetReal!(instance, ctx.output_value_references, m.outputs)
341+
end
342+
343+
return m
344+
end
345+
346+
function fmi2CSStep!(m, o, ctx::FMI2CSFunctor, integrator)
347+
wrapper = o.wrapper
348+
states = isdefined(m, :states) ? m.states : ()
349+
inputs = o.inputs
350+
params = o.params
351+
t = o.t
352+
dt = o.dt
353+
354+
instance = get_instance_CS!(wrapper, states, inputs, params, integrator.t)
355+
@statuscheck FMI.fmi2DoStep(instance, integrator.t - dt, dt, FMI.fmi2True)
356+
357+
if isdefined(m, :states)
358+
@statuscheck FMI.fmi2GetReal!(instance, ctx.state_value_references, m.states)
359+
end
360+
if isdefined(m, :outputs)
361+
@statuscheck FMI.fmi2GetReal!(instance, ctx.output_value_references, m.outputs)
362+
end
363+
364+
return m
365+
end
366+
223367
end # module

0 commit comments

Comments
 (0)