Skip to content

Commit 2588d11

Browse files
authored
Merge pull request #243 from JuliaDynamics/hw/workshop
updates for workshop
2 parents 05b4b9f + fbc7f71 commit 2588d11

File tree

11 files changed

+38
-16
lines changed

11 files changed

+38
-16
lines changed

NetworkDynamicsInspector/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ Electron = "a1bb12fb-d4d1-54b4-b10a-ee7951ef7ad3"
2323
NetworkDynamics = {path = ".."}
2424

2525
[extensions]
26-
ElectronExt = ["Electron"]
26+
NetworkDynamicsInspectorElectronExt = ["Electron"]
2727

2828
[compat]
2929
Bonito = "4.0.0"

NetworkDynamicsInspector/assets/app.css

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ span {
2020
.maingrid {
2121
display: grid;
2222
grid-template-columns: min-content auto;
23+
align-items: start; /* make sure ts col is as short as possible for resize event */
2324
width: 100%;
2425
gap: 0px;
2526
}

NetworkDynamicsInspector/ext/ElectronExt.jl renamed to NetworkDynamicsInspector/ext/NetworkDynamicsInspectorElectronExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module ElectronExt
1+
module NetworkDynamicsInspectorElectronExt
22

33
using NetworkDynamicsInspector: NetworkDynamicsInspector as NDI
44
using Electron: Electron, windows

NetworkDynamicsInspector/src/NetworkDynamicsInspector.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using NetworkDynamics: NetworkDynamics, SII, EIndex, VIndex, Network,
77
obssym, psym, sym, extract_nw
88

99
using Graphs: nv, ne
10-
using WGLMakie: WGLMakie
10+
using WGLMakie: WGLMakie, WithConfig
1111
using WGLMakie.Makie: Makie, @lift, MouseEvent, Point2f, with_theme,
1212
lines!, vlines!, Theme, Figure, Colorbar, Axis,
1313
xlims!, ylims!, autolimits!, hidespines!, hidedecorations!,
@@ -65,7 +65,6 @@ function get_webapp(app)
6565
@info "GUI Session updated"
6666
end
6767

68-
WGLMakie.activate!(resize_to=:parent)
6968
clear_obs!(app)
7069

7170
resize_gp = js"""

NetworkDynamicsInspector/src/graphplot.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ function graphplot_card(app, session)
220220
<li><strong>Ctrl + Click</strong> resets axis after zoom</li>
221221
</ul>
222222
""")
223-
Card([fig, help]; class="bonito-card graphplot-card")
223+
Card([WithConfig(fig; resize_to=:parent), help]; class="bonito-card graphplot-card")
224224
end
225225
function _gracefully_extract_states!(vec, sol, t, idxs, rel)
226226
isvalid(s) = SII.is_variable(sol, s) || SII.is_parameter(sol, s) || SII.is_observed(sol, s)
@@ -379,7 +379,7 @@ function gpstate_control_card(app, type)
379379

380380
childs = Any[DOM.div(
381381
selector,
382-
DOM.div(fig; style=Styles("height" => "40px")),
382+
DOM.div(WithConfig(fig; resize_to=:parent); style=Styles("height" => "40px")),
383383
# fig,
384384
# RoundedLabel(@lift $maxrange[1]; style=Styles("text-align"=>"right")),
385385
cslider,

NetworkDynamicsInspector/src/timeseries.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ function timeseries_card(app, key, session)
444444
card = Card(
445445
[DOM.div(
446446
comp_state_sel_dom,
447-
DOM.div(fig; class="timeseries-axis-container"),
447+
DOM.div(WithConfig(fig; resize_to=:parent); class="timeseries-axis-container"),
448448
closebutton(app, key);
449449
class="timeseries-card-container"
450450
), help];

src/aggregators.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727
NaiveAggregator(f) = (im, batches) -> NaiveAggregator(im, batches, f)
2828

2929
function aggregate!(a::NaiveAggregator, aggbuf, data)
30-
fill!(aggbuf, zero(eltype(aggbuf)))
30+
fill!(aggbuf, _appropriate_zero(aggbuf))
3131
_aggregate!(a, a.batches, aggbuf, data)
3232
end
3333
function _aggregate!(a::NaiveAggregator, batches, aggbuf, data)
@@ -101,7 +101,7 @@ KAAggregator(im, batches, f) = KAAggregator(f, AggregationMap(im, batches))
101101

102102
function aggregate!(a::KAAggregator, aggbuf, data)
103103
am = a.m
104-
fill!(aggbuf, zero(eltype(aggbuf)))
104+
fill!(aggbuf, _appropriate_zero(aggbuf))
105105
_backend = get_backend(data)
106106
# kernel = agg_kernel!(_backend, 1024, length(am.map))
107107
# kernel(a.f, aggbuf, view(data, am.range), am.map)
@@ -140,7 +140,7 @@ SequentialAggregator(f) = (im, batches) -> SequentialAggregator(im, batches, f)
140140
SequentialAggregator(im, batches, f) = SequentialAggregator(f, AggregationMap(im, batches))
141141

142142
function aggregate!(a::SequentialAggregator, aggbuf, data)
143-
fill!(aggbuf, zero(eltype(aggbuf)))
143+
fill!(aggbuf, _appropriate_zero(aggbuf))
144144

145145
am = a.m
146146
@inbounds begin
@@ -169,7 +169,7 @@ PolyesterAggregator(im, batches, f) = PolyesterAggregator(f, _inv_aggregation_ma
169169

170170
function aggregate!(a::PolyesterAggregator, aggbuf, data)
171171
length(a.m) == length(aggbuf) || throw(DimensionMismatch("length of aggbuf and a.m must be equal"))
172-
fill!(aggbuf, zero(eltype(aggbuf)))
172+
fill!(aggbuf, _appropriate_zero(aggbuf))
173173

174174
maxdepth = mapreduce(x -> length(x[2]), max, a.m)
175175

@@ -196,7 +196,7 @@ ThreadedAggregator(im, batches, f) = ThreadedAggregator(f, _inv_aggregation_map(
196196

197197
function aggregate!(a::ThreadedAggregator, aggbuf, data)
198198
length(a.m) == length(aggbuf) || throw(DimensionMismatch("length of aggbuf and a.m must be equal"))
199-
fill!(aggbuf, zero(eltype(aggbuf)))
199+
fill!(aggbuf, _appropriate_zero(aggbuf))
200200

201201
Threads.@threads for (dstidx, srcidxs) in a.m
202202
@inbounds for srcidx in srcidxs
@@ -299,3 +299,11 @@ get_aggr_constructor(a::SparseAggregator) = SparseAggregator(+)
299299

300300
iscudacompatible(::Type{<:KAAggregator}) = true
301301
iscudacompatible(::Type{<:SparseAggregator}) = true
302+
303+
function _appropriate_zero(x)
304+
if isconcretetype(eltype(x))
305+
zero(eltype(x))
306+
else
307+
0.0 # hopefully that casts to what is needed
308+
end
309+
end

src/metadata.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ Sets the callback function for the component. Overwrites any existing callback.
393393
See also [`add_callback!`](@ref).
394394
"""
395395
function set_callback!(c::ComponentModel, cb; check=true)
396-
if !(cb isa ComponentCallback) && !(cb isa NTuple{N, <:ComponentCallback} where N)
396+
if !(cb isa ComponentCallback) && !(cb isa Tuple && all(c -> c isa ComponentCallback, cb))
397397
throw(ArgumentError("Callback must be a ComponentCallback or a tuple of ComponentCallbacks, got $(typeof(cb))."))
398398
end
399399
check && assert_cb_compat(c, cb)

src/symbolicindexing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,10 +555,11 @@ function _get_observed_f(im::IndexManager, cf::VertexModel, vidx, _obsf::O) wher
555555
aggr = im.v_aggr[vidx]
556556
extr = im.v_out[vidx]
557557
pr = im.v_para[vidx]
558-
ret = Vector{Float64}(undef, N)
558+
retcache = DiffCache(Vector{Float64}(undef, N))
559559
_hasext = has_external_input(cf)
560560

561561
(u, outbuf, aggbuf, extbuf, p, t) -> begin
562+
ret = PreallocationTools.get_tmp(retcache, first(u)*first(p)*first(t))
562563
ins = if _hasext
563564
(view(aggbuf, aggr), view(extbuf, extr))
564565
else

test/aggregators_test.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using Chairmarks
66
using InteractiveUtils
77
using Test
88
using StableRNGs
9+
using ForwardDiff: Dual
10+
using Symbolics
911

1012
(isinteractive() && @__MODULE__()==Main ? includet : include)("ComponentLibrary.jl")
1113

@@ -61,3 +63,14 @@ using StableRNGs
6163
@test issame
6264
end
6365
end
66+
67+
@testset "Test _appropriate_zero" begin
68+
@test NetworkDynamics._appropriate_zero([1,2,3]) isa Int
69+
@test NetworkDynamics._appropriate_zero([1,2,3]) == 0
70+
@test NetworkDynamics._appropriate_zero([1.0,2,3]) isa Float64
71+
@test NetworkDynamics._appropriate_zero([1.0,2,3]) == 0.0
72+
@test NetworkDynamics._appropriate_zero([Dual(1.0), 2, 3]) == Dual(0.0)
73+
@variables x, y, z
74+
@test NetworkDynamics._appropriate_zero([x,y,z]) isa Num
75+
@test NetworkDynamics._appropriate_zero(Any[x,y,z]) isa Float64
76+
end

0 commit comments

Comments
 (0)