11symconvert (:: Type{Symbolics.Struct{T}} , x) where {T} = convert (T, x)
22symconvert (:: Type{T} , x) where {T} = convert (T, x)
3+ symconvert (:: Type{Real} , x:: Integer ) = convert (Float64, x)
4+ symconvert (:: Type{V} , x) where {V <: AbstractArray } = convert (V, symconvert .(eltype (V), x))
5+
36struct MTKParameters{T, D, C, E, N, F, G}
47 tunable:: T
58 discrete:: D
@@ -67,7 +70,7 @@ function MTKParameters(
6770 end
6871 end
6972
70- isempty (missing_params) || throw (MissingVariablesError (collect (missing_params)))
73+ isempty (missing_params) || throw (MissingParametersError (collect (missing_params)))
7174
7275 tunable_buffer = Tuple (Vector {temp.type} (undef, temp. length)
7376 for temp in ic. tunable_buffer_sizes)
@@ -107,7 +110,7 @@ function MTKParameters(
107110 for (sym, val) in p
108111 sym = unwrap (sym)
109112 val = unwrap (val)
110- ctype = concrete_symtype (sym)
113+ ctype = symtype (sym)
111114 if symbolic_type (val) != = NotSymbolic ()
112115 continue
113116 end
@@ -126,19 +129,27 @@ function MTKParameters(
126129 end
127130 end
128131 end
132+ tunable_buffer = narrow_buffer_type .(tunable_buffer)
133+ disc_buffer = narrow_buffer_type .(disc_buffer)
134+ const_buffer = narrow_buffer_type .(const_buffer)
135+ nonnumeric_buffer = narrow_buffer_type .(nonnumeric_buffer)
129136
130137 if has_parameter_dependencies (sys) &&
131138 (pdeps = get_parameter_dependencies (sys)) != = nothing
132139 pdeps = Dict (k => fixpoint_sub (v, pdeps) for (k, v) in pdeps)
133- dep_exprs = ArrayPartition ((wrap . (v) for v in dep_buffer). .. )
140+ dep_exprs = ArrayPartition ((Any[ missing for _ in 1 : length (v)] for v in dep_buffer). .. )
134141 for (sym, val) in pdeps
135142 i, j = ic. dependent_idx[sym]
136143 dep_exprs. x[i][j] = wrap (val)
137144 end
145+ dep_exprs = identity .(dep_exprs)
138146 p = reorder_parameters (ic, full_parameters (sys))
139147 oop, iip = build_function (dep_exprs, p... )
140148 update_function_iip, update_function_oop = RuntimeGeneratedFunctions. @RuntimeGeneratedFunction (iip),
141149 RuntimeGeneratedFunctions. @RuntimeGeneratedFunction (oop)
150+ update_function_iip (ArrayPartition (dep_buffer), tunable_buffer... , disc_buffer... ,
151+ const_buffer... , nonnumeric_buffer... , dep_buffer... )
152+ dep_buffer = narrow_buffer_type .(dep_buffer)
142153 else
143154 update_function_iip = update_function_oop = nothing
144155 end
@@ -148,12 +159,26 @@ function MTKParameters(
148159 typeof (dep_buffer), typeof (nonnumeric_buffer), typeof (update_function_iip),
149160 typeof (update_function_oop)}(tunable_buffer, disc_buffer, const_buffer, dep_buffer,
150161 nonnumeric_buffer, update_function_iip, update_function_oop)
151- if mtkps. dependent_update_iip != = nothing
152- mtkps. dependent_update_iip (ArrayPartition (mtkps. dependent), mtkps... )
153- end
154162 return mtkps
155163end
156164
165+ function narrow_buffer_type (buffer:: AbstractArray )
166+ type = Union{}
167+ for x in buffer
168+ type = promote_type (type, typeof (x))
169+ end
170+ return convert .(type, buffer)
171+ end
172+
173+ function narrow_buffer_type (buffer:: AbstractArray{<:AbstractArray} )
174+ buffer = narrow_buffer_type .(buffer)
175+ type = Union{}
176+ for x in buffer
177+ type = promote_type (type, eltype (x))
178+ end
179+ return broadcast .(convert, type, buffer)
180+ end
181+
157182function buffer_to_arraypartition (buf)
158183 return ArrayPartition (ntuple (i -> _buffer_to_arrp_helper (buf[i]), Val (length (buf))))
159184end
@@ -550,3 +575,17 @@ function as_duals(p::MTKParameters, dualtype)
550575 discrete = dualtype .(p. discrete)
551576 return MTKParameters {typeof(tunable), typeof(discrete)} (tunable, discrete)
552577end
578+
579+ const MISSING_PARAMETERS_MESSAGE = """
580+ Some parameters are missing from the variable map.
581+ Please provide a value or default for the following variables:
582+ """
583+
584+ struct MissingParametersError <: Exception
585+ vars:: Any
586+ end
587+
588+ function Base. showerror (io:: IO , e:: MissingParametersError )
589+ println (io, MISSING_PARAMETERS_MESSAGE)
590+ println (io, e. vars)
591+ end
0 commit comments