Skip to content

Commit f50ebe4

Browse files
committed
fix tests
1 parent c7dbfea commit f50ebe4

File tree

5 files changed

+33
-28
lines changed

5 files changed

+33
-28
lines changed

src/component_functions.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,11 +1121,12 @@ function Base.:(==)(cf1::ComponentModel, cf2::ComponentModel)
11211121
typeof(cf1) == typeof(cf2) && equal_fields(cf1, cf2)
11221122
end
11231123

1124-
function compfg(c)
1124+
# force specialization on f, g, fft
1125+
compfg(c) = _compfg(compf(c), compg(c), fftype(c))
1126+
function _compfg(f::F, g::G, fft::FFT) where {F, G, FFT}
11251127
(outs, du, u, ins, p, t) -> begin
1126-
f = compf(c)
11271128
isnothing(f) || f(du, u, ins..., p, t)
1128-
compg(c)(_gargs(fftype(c), outs, du, u, ins, p, t)...)
1129+
g(_gargs(fft, outs, du, u, ins, p, t)...)
11291130
nothing
11301131
end
11311132
end

src/metadata.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ function _get_initial_observed(cf)
318318
(get_defaults_or_inits(cf, insym(cf); missing_val), )
319319
end
320320
p = get_defaults_or_inits(cf, psym(cf); missing_val)
321-
cf.obsf(obs, u, ins..., p, NaN)
321+
obsf(cf)(obs, u, ins..., p, NaN)
322322
obs
323323
end
324324

src/show.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Base.show(io::IO, s::SequentialAggregator) = print(io, "SequentialAggregator($(r
4242
Base.show(io::IO, s::PolyesterAggregator) = print(io, "PolyesterAggregator($(repr(s.f)))")
4343

4444
function Base.show(io::IO, ::MIME"text/plain", c::ComponentModel)
45-
type = match(r"^(.*?)\{", string(typeof(c)))[1]
45+
type = string(typeof(c))
4646
print(io, type, styled" {NetworkDynamics_name::$(c.name)}")
4747
print(io, styled" {NetworkDynamics_fftype:$(fftype(c))}")
4848
if has_graphelement(c)

src/symbolicindexing.jl

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -531,40 +531,44 @@ function _is_normalized(snis)
531531
end
532532
_is_normalized(snis::SymbolicIndex) = SII.symbolic_type(snis) === SII.ScalarSymbolic()
533533

534-
function _get_observed_f(nw::Network, cf::VertexModel, vidx)
534+
# function barrier for obsf
535+
_get_observed_f(nw, cf, vidx) = _get_observed_f(nw.im, cf, vidx, obsf(cf))
536+
function _get_observed_f(im::IndexManager, cf::VertexModel, vidx, _obsf::O) where {O}
535537
N = length(cf.obssym)
536-
ur = nw.im.v_data[vidx]
537-
aggr = nw.im.v_aggr[vidx]
538-
extr = nw.im.v_out[vidx]
539-
pr = nw.im.v_para[vidx]
538+
ur = im.v_data[vidx]
539+
aggr = im.v_aggr[vidx]
540+
extr = im.v_out[vidx]
541+
pr = im.v_para[vidx]
540542
ret = Vector{Float64}(undef, N)
543+
_hasext = has_external_input(cf)
541544

542-
@closure (u, outbuf, aggbuf, extbuf, p, t) -> begin
543-
ins = if has_external_input(cf)
545+
(u, outbuf, aggbuf, extbuf, p, t) -> begin
546+
ins = if _hasext
544547
(view(aggbuf, aggr), view(extbuf, extr))
545548
else
546549
(view(aggbuf, aggr), )
547550
end
548-
cf.obsf(ret, view(u, ur), ins..., view(p, pr), t)
551+
_obsf(ret, view(u, ur), ins..., view(p, pr), t)
549552
ret
550553
end
551554
end
552-
553-
function _get_observed_f(nw::Network, cf::EdgeModel, eidx)
555+
function _get_observed_f(im::IndexManager, cf::EdgeModel, eidx, _obsf::O) where {O}
554556
N = length(cf.obssym)
555-
ur = nw.im.e_data[eidx]
556-
esrcr = nw.im.v_out[nw.im.edgevec[eidx].src]
557-
edstr = nw.im.v_out[nw.im.edgevec[eidx].dst]
558-
extr = nw.im.e_out[eidx]
559-
pr = nw.im.e_para[eidx]
557+
ur = im.e_data[eidx]
558+
esrcr = im.v_out[im.edgevec[eidx].src]
559+
edstr = im.v_out[im.edgevec[eidx].dst]
560+
extr = im.e_out[eidx]
561+
pr = im.e_para[eidx]
560562
ret = Vector{Float64}(undef, N)
563+
_hasext = has_external_input(cf)
561564

562-
@closure (u, outbuf, aggbuf, extbuf, p, t) -> begin
563-
ins = (view(outbuf, esrcr), view(outbuf, edstr))
564-
if has_external_input(cf)
565-
(ins..., view(extbuf, extr))
565+
(u, outbuf, aggbuf, extbuf, p, t) -> begin
566+
ins = if _hasext
567+
(view(outbuf, esrcr), view(outbuf, edstr), view(extbuf, extr))
568+
else
569+
(view(outbuf, esrcr), view(outbuf, edstr))
566570
end
567-
cf.obsf(ret, view(u, ur), ins..., view(p, pr), t)
571+
_obsf(ret, view(u, ur), ins..., view(p, pr), t)
568572
ret
569573
end
570574
end

test/symbolicindexing_test.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ end
525525
idxs1 = [VIndex(1,), VIndex(2, :Pdamping), EIndex(1,:P), VIndex(2,:P)]
526526
idxs2 = [VIndex(1,), VIndex(2,)]
527527
# full call
528-
# @b $s[$idxs1] # 134 106 94
528+
# @b $s[$idxs1] # 134 106 94 101
529529
# @b $s[$idxs2] # 31 31 34
530530

531531
# scalar call
@@ -534,9 +534,9 @@ end
534534
# @b SII.observed($nw, $(VIndex(2,:Pdamping))) # 15
535535
# @b SII.observed($nw, $(VIndex(2,:θ))) # 7 5
536536

537-
b = @b SII.observed($nw, $idxs1) # 69 36 42 30
537+
b = @b SII.observed($nw, $idxs1) # 69 36 42 30 37
538538
if VERSION v"1.11"
539-
@test b.allocs <= 30
539+
@test b.allocs <= 37
540540
end
541541
b = @b SII.observed($nw, $idxs2) # 12 7 10 5
542542
if VERSION v"1.11"

0 commit comments

Comments
 (0)