@@ -28,7 +28,11 @@ the default behavior).
2828"""
2929function MTKParameters (
3030 sys:: AbstractSystem , p, u0 = Dict (); tofloat = false ,
31- t0 = nothing , substitution_limit = 1000 , floatT = nothing )
31+ t0 = nothing , substitution_limit = 1000 , floatT = nothing ,
32+ container_type = Vector)
33+ if ! (container_type <: AbstractArray )
34+ container_type = Array
35+ end
3236 ic = if has_index_cache (sys) && get_index_cache (sys) != = nothing
3337 get_index_cache (sys)
3438 else
@@ -133,18 +137,23 @@ function MTKParameters(
133137 end
134138 end
135139 end
136- tunable_buffer = narrow_buffer_type (tunable_buffer)
140+ tunable_buffer = narrow_buffer_type (tunable_buffer; container_type )
137141 if isempty (tunable_buffer)
138142 tunable_buffer = SizedVector {0, Float64} ()
139143 end
140- initials_buffer = narrow_buffer_type (initials_buffer)
144+ initials_buffer = narrow_buffer_type (initials_buffer; container_type )
141145 if isempty (initials_buffer)
142146 initials_buffer = SizedVector {0, Float64} ()
143147 end
144- disc_buffer = narrow_buffer_type .(disc_buffer)
145- const_buffer = narrow_buffer_type .(const_buffer)
148+ disc_buffer = narrow_buffer_type .(disc_buffer; container_type )
149+ const_buffer = narrow_buffer_type .(const_buffer; container_type )
146150 # Don't narrow nonnumeric types
147- nonnumeric_buffer = nonnumeric_buffer
151+ if ! isempty (nonnumeric_buffer)
152+ nonnumeric_buffer = map (nonnumeric_buffer) do buf
153+ SymbolicUtils. Code. create_array (
154+ container_type, nothing , Val (1 ), Val (length (buf)), buf... )
155+ end
156+ end
148157
149158 mtkps = MTKParameters{
150159 typeof (tunable_buffer), typeof (initials_buffer), typeof (disc_buffer),
@@ -160,21 +169,44 @@ function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate..
160169 @set p. caches = buffers
161170end
162171
163- function narrow_buffer_type (buffer:: AbstractArray )
172+ function narrow_buffer_type (buffer:: AbstractArray ; container_type = typeof (buffer) )
164173 type = Union{}
165174 for x in buffer
166175 type = promote_type (type, typeof (x))
167176 end
168- return convert .(type, buffer)
177+ return SymbolicUtils. Code. create_array (
178+ container_type, type, Val (ndims (buffer)), Val (length (buffer)), buffer... )
169179end
170180
171- function narrow_buffer_type (buffer:: AbstractArray{<:AbstractArray} )
172- buffer = narrow_buffer_type .(buffer)
181+ function narrow_buffer_type (
182+ buffer:: AbstractArray{<:AbstractArray} ; container_type = typeof (buffer))
183+ type = Union{}
184+ for arr in buffer
185+ for x in arr
186+ type = promote_type (type, typeof (x))
187+ end
188+ end
189+ buffer = map (buffer) do buf
190+ SymbolicUtils. Code. create_array (
191+ container_type, type, Val (ndims (buf)), Val (size (buf)), buf... )
192+ end
193+ return SymbolicUtils. Code. create_array (
194+ container_type, nothing , Val (ndims (buffer)), Val (size (buffer)), buffer... )
195+ end
196+
197+ function narrow_buffer_type (buffer:: BlockedArray ; container_type = typeof (parent (buffer)))
173198 type = Union{}
174199 for x in buffer
175- type = promote_type (type, eltype (x))
200+ type = promote_type (type, typeof (x))
201+ end
202+ tmp = SymbolicUtils. Code. create_array (
203+ container_type, type, Val (ndims (buffer)), Val (size (buffer)), buffer... )
204+ blocks = ntuple (Val (ndims (buffer))) do i
205+ bsizes = blocksizes (buffer, i)
206+ SymbolicUtils. Code. create_array (
207+ container_type, Int, Val (1 ), Val (length (bsizes)), bsizes... )
176208 end
177- return broadcast .(convert, type, buffer )
209+ return BlockedArray (tmp, blocks ... )
178210end
179211
180212function buffer_to_arraypartition (buf)
0 commit comments