Skip to content

Commit b1818ba

Browse files
refactor: remove container_type kwarg of MTKParameters, use p_constructor
1 parent 5189598 commit b1818ba

File tree

2 files changed

+17
-31
lines changed

2 files changed

+17
-31
lines changed

src/systems/parameter_buffer.jl

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,7 @@ the default behavior).
2929
function MTKParameters(
3030
sys::AbstractSystem, p, u0 = Dict(); tofloat = false,
3131
t0 = nothing, substitution_limit = 1000, floatT = nothing,
32-
container_type = Vector, p_constructor = identity)
33-
if !(container_type <: AbstractArray)
34-
throw(ArgumentError("""
35-
`container_type` for `MTKParameters` must be a subtype of `AbstractArray`. Found \
36-
$container_type.
37-
"""))
38-
end
32+
p_constructor = identity)
3933
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
4034
get_index_cache(sys)
4135
else
@@ -140,22 +134,19 @@ function MTKParameters(
140134
end
141135
end
142136
end
143-
tunable_buffer = p_constructor(narrow_buffer_type(tunable_buffer; container_type))
137+
tunable_buffer = narrow_buffer_type(tunable_buffer; p_constructor)
144138
if isempty(tunable_buffer)
145139
tunable_buffer = SizedVector{0, Float64}()
146140
end
147-
initials_buffer = p_constructor(narrow_buffer_type(initials_buffer; container_type))
141+
initials_buffer = narrow_buffer_type(initials_buffer; p_constructor)
148142
if isempty(initials_buffer)
149143
initials_buffer = SizedVector{0, Float64}()
150144
end
151-
disc_buffer = p_constructor.(narrow_buffer_type.(disc_buffer; container_type))
152-
const_buffer = p_constructor.(narrow_buffer_type.(const_buffer; container_type))
145+
disc_buffer = narrow_buffer_type.(disc_buffer; p_constructor)
146+
const_buffer = narrow_buffer_type.(const_buffer; p_constructor)
153147
# Don't narrow nonnumeric types
154148
if !isempty(nonnumeric_buffer)
155-
nonnumeric_buffer = map(nonnumeric_buffer) do buf
156-
p_constructor(SymbolicUtils.Code.create_array(
157-
container_type, eltype(buf), Val(1), Val(length(buf)), buf...))
158-
end
149+
nonnumeric_buffer = map(p_constructor, nonnumeric_buffer)
159150
end
160151

161152
mtkps = MTKParameters{
@@ -172,45 +163,40 @@ function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate..
172163
@set p.caches = buffers
173164
end
174165

175-
function narrow_buffer_type(buffer::AbstractArray; container_type = typeof(buffer))
166+
function narrow_buffer_type(buffer::AbstractArray; p_constructor = identity)
176167
type = Union{}
177168
for x in buffer
178169
type = promote_type(type, typeof(x))
179170
end
180-
return SymbolicUtils.Code.create_array(
181-
container_type, type, Val(ndims(buffer)), Val(length(buffer)), buffer...)
171+
return p_constructor(type.(buffer))
182172
end
183173

184174
function narrow_buffer_type(
185-
buffer::AbstractArray{<:AbstractArray}; container_type = typeof(buffer))
175+
buffer::AbstractArray{<:AbstractArray}; p_constructor = identity)
186176
type = Union{}
187177
for arr in buffer
188178
for x in arr
189179
type = promote_type(type, typeof(x))
190180
end
191181
end
192182
buffer = map(buffer) do buf
193-
SymbolicUtils.Code.create_array(
194-
container_type, type, Val(ndims(buf)), Val(size(buf)), buf...)
183+
p_constructor(type.(buf))
195184
end
196-
return SymbolicUtils.Code.create_array(
197-
container_type, nothing, Val(ndims(buffer)), Val(size(buffer)), buffer...)
185+
return p_constructor(buffer)
198186
end
199187

200-
function narrow_buffer_type(buffer::BlockedArray; container_type = typeof(parent(buffer)))
188+
function narrow_buffer_type(buffer::BlockedArray; p_constructor = identity)
201189
if eltype(buffer) <: AbstractArray
202-
buffer = narrow_buffer_type.(buffer; container_type)
190+
buffer = narrow_buffer_type.(buffer; p_constructor)
203191
end
204192
type = Union{}
205193
for x in buffer
206194
type = promote_type(type, typeof(x))
207195
end
208-
tmp = SymbolicUtils.Code.create_array(
209-
container_type, type, Val(ndims(buffer)), Val(size(buffer)), buffer...)
196+
tmp = p_constructor(type.(buffer))
210197
blocks = ntuple(Val(ndims(buffer))) do i
211198
bsizes = blocksizes(buffer, i)
212-
SymbolicUtils.Code.create_array(
213-
container_type, Int, Val(1), Val(length(bsizes)), bsizes...)
199+
p_constructor(Int.(bsizes))
214200
end
215201
return BlockedArray(tmp, blocks...)
216202
end

test/mtkparameters.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ end
2727
@test getp(sys, a)(ps) == getp(sys, b)(ps) == getp(sys, c)(ps) == 0.0
2828
@test getp(sys, d)(ps) isa Int
2929

30-
@testset "`container_type`" begin
31-
ps2 = MTKParameters(sys, ivs; container_type = SVector)
30+
@testset "`p_constructor`" begin
31+
ps2 = MTKParameters(sys, ivs; p_constructor = x -> SArray{Tuple{size(x)...}}(x))
3232
@test ps2.tunable isa SVector
3333
@test ps2.initials isa SVector
3434
@test ps2.discrete isa Tuple{<:BlockedVector{Float64, <:SVector}}

0 commit comments

Comments
 (0)