@@ -29,13 +29,7 @@ the default behavior).
2929function MTKParameters (
3030 sys:: AbstractSystem , p, u0 = Dict (); tofloat = false ,
3131 t0 = nothing , substitution_limit = 1000 , floatT = nothing ,
32- container_type = Vector, p_constructor = identity)
33- if ! (container_type <: AbstractArray )
34- throw (ArgumentError ("""
35- `container_type` for `MTKParameters` must be a subtype of `AbstractArray`. Found \
36- $container_type .
37- """ ))
38- end
32+ p_constructor = identity)
3933 ic = if has_index_cache (sys) && get_index_cache (sys) != = nothing
4034 get_index_cache (sys)
4135 else
@@ -140,22 +134,19 @@ function MTKParameters(
140134 end
141135 end
142136 end
143- tunable_buffer = p_constructor ( narrow_buffer_type (tunable_buffer; container_type) )
137+ tunable_buffer = narrow_buffer_type (tunable_buffer; p_constructor )
144138 if isempty (tunable_buffer)
145139 tunable_buffer = SizedVector {0, Float64} ()
146140 end
147- initials_buffer = p_constructor ( narrow_buffer_type (initials_buffer; container_type) )
141+ initials_buffer = narrow_buffer_type (initials_buffer; p_constructor )
148142 if isempty (initials_buffer)
149143 initials_buffer = SizedVector {0, Float64} ()
150144 end
151- disc_buffer = p_constructor .( narrow_buffer_type .(disc_buffer; container_type) )
152- const_buffer = p_constructor .( narrow_buffer_type .(const_buffer; container_type) )
145+ disc_buffer = narrow_buffer_type .(disc_buffer; p_constructor )
146+ const_buffer = narrow_buffer_type .(const_buffer; p_constructor )
153147 # Don't narrow nonnumeric types
154148 if ! isempty (nonnumeric_buffer)
155- nonnumeric_buffer = map (nonnumeric_buffer) do buf
156- p_constructor (SymbolicUtils. Code. create_array (
157- container_type, eltype (buf), Val (1 ), Val (length (buf)), buf... ))
158- end
149+ nonnumeric_buffer = map (p_constructor, nonnumeric_buffer)
159150 end
160151
161152 mtkps = MTKParameters{
@@ -172,45 +163,40 @@ function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate..
172163 @set p. caches = buffers
173164end
174165
175- function narrow_buffer_type (buffer:: AbstractArray ; container_type = typeof (buffer) )
166+ function narrow_buffer_type (buffer:: AbstractArray ; p_constructor = identity )
176167 type = Union{}
177168 for x in buffer
178169 type = promote_type (type, typeof (x))
179170 end
180- return SymbolicUtils. Code. create_array (
181- container_type, type, Val (ndims (buffer)), Val (length (buffer)), buffer... )
171+ return p_constructor (type .(buffer))
182172end
183173
184174function narrow_buffer_type (
185- buffer:: AbstractArray{<:AbstractArray} ; container_type = typeof (buffer) )
175+ buffer:: AbstractArray{<:AbstractArray} ; p_constructor = identity )
186176 type = Union{}
187177 for arr in buffer
188178 for x in arr
189179 type = promote_type (type, typeof (x))
190180 end
191181 end
192182 buffer = map (buffer) do buf
193- SymbolicUtils. Code. create_array (
194- container_type, type, Val (ndims (buf)), Val (size (buf)), buf... )
183+ p_constructor (type .(buf))
195184 end
196- return SymbolicUtils. Code. create_array (
197- container_type, nothing , Val (ndims (buffer)), Val (size (buffer)), buffer... )
185+ return p_constructor (buffer)
198186end
199187
200- function narrow_buffer_type (buffer:: BlockedArray ; container_type = typeof ( parent (buffer)) )
188+ function narrow_buffer_type (buffer:: BlockedArray ; p_constructor = identity )
201189 if eltype (buffer) <: AbstractArray
202- buffer = narrow_buffer_type .(buffer; container_type )
190+ buffer = narrow_buffer_type .(buffer; p_constructor )
203191 end
204192 type = Union{}
205193 for x in buffer
206194 type = promote_type (type, typeof (x))
207195 end
208- tmp = SymbolicUtils. Code. create_array (
209- container_type, type, Val (ndims (buffer)), Val (size (buffer)), buffer... )
196+ tmp = p_constructor (type .(buffer))
210197 blocks = ntuple (Val (ndims (buffer))) do i
211198 bsizes = blocksizes (buffer, i)
212- SymbolicUtils. Code. create_array (
213- container_type, Int, Val (1 ), Val (length (bsizes)), bsizes... )
199+ p_constructor (Int .(bsizes))
214200 end
215201 return BlockedArray (tmp, blocks... )
216202end
0 commit comments