Skip to content

Commit 9130c53

Browse files
committed
create batches based on hash for faster comparison
1 parent 7561430 commit 9130c53

File tree

5 files changed

+72
-50
lines changed

5 files changed

+72
-50
lines changed

src/callbacks.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ function wrap_component_callbacks(nw)
235235
end
236236
# group the callbacks such that they are in groups which are "batchequal"
237237
# batchequal groups can be wrapped into a single callback
238-
idx_per_type = _find_identical(callbacks, 1:length(components))
238+
idx_per_type = find_identical(callbacks; equality=_batchequal)
239239
batches = []
240240
for typeidx in idx_per_type
241241
batchcomps = components[typeidx]
@@ -251,21 +251,22 @@ function wrap_component_callbacks(nw)
251251
end
252252
return batches
253253
end
254-
function batchequal(a::ContinousComponentCallback, b::ContinousComponentCallback)
255-
batchequal(a.condition, b.condition) || return false
256-
batchequal(a.kwargs, b.kwargs) || return false
254+
_batchequal(a, b) = false
255+
function _batchequal(a::ContinousComponentCallback, b::ContinousComponentCallback)
256+
_batchequal(a.condition, b.condition) || return false
257+
_batchequal(a.kwargs, b.kwargs) || return false
257258
return true
258259
end
259-
function batchequal(a::VectorContinousComponentCallback, b::VectorContinousComponentCallback)
260-
batchequal(a.condition, b.condition) || return false
261-
batchequal(a.kwargs, b.kwargs) || return false
260+
function _batchequal(a::VectorContinousComponentCallback, b::VectorContinousComponentCallback)
261+
_batchequal(a.condition, b.condition) || return false
262+
_batchequal(a.kwargs, b.kwargs) || return false
262263
a.len == b.len || return false
263264
return true
264265
end
265-
function batchequal(a::T, b::T) where {T <: Union{ComponentCondition, ComponentAffect}}
266+
function _batchequal(a::T, b::T) where {T <: Union{ComponentCondition, ComponentAffect}}
266267
typeof(a) == typeof(b)
267268
end
268-
function batchequal(a::NamedTuple, b::NamedTuple)
269+
function _batchequal(a::NamedTuple, b::NamedTuple)
269270
length(a) == length(b) || return false
270271
for (k, v) in a
271272
haskey(b, k) || return false

src/component_functions.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -579,19 +579,6 @@ dispatchT(::T) where {T<:ComponentModel} = dispatchT(T)
579579
dispatchT(T::Type{<:VertexModel}) = VertexModel
580580
dispatchT(T::Type{<:EdgeModel}) = EdgeModel
581581

582-
# TODO: introduce batchequal hash for faster batching of component models
583-
batchequal(a, b) = false
584-
function batchequal(a::ComponentModel, b::ComponentModel)
585-
compf(a) === compf(b) || return false
586-
compg(a) === compg(b) || return false
587-
fftype(a) == fftype(b) || return false
588-
dim(a) == dim(b) || return false
589-
outdim(a) == outdim(b) || return false
590-
pdim(a) == pdim(b) || return false
591-
extdim(a) == extdim(b) || return false
592-
return true
593-
end
594-
595582
"""
596583
_construct_comp(::Type{T}, kwargs) where {T}
597584

src/construction.jl

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,18 @@ function Network(g::AbstractGraph,
128128

129129
# batch identical edge and vertex model
130130
@timeit_debug "batch identical vertexes" begin
131-
vidxs = _find_identical(vertexm, 1:nv(g))
131+
vidxs = if all_same_v
132+
[collect(1:nv(g))]
133+
else
134+
_find_identical_components(_vertexm)
135+
end
132136
end
133137
@timeit_debug "batch identical edges" begin
134-
eidxs = _find_identical(edgem, 1:ne(g))
138+
eidxs = if all_same_e
139+
[collect(1:ne(g))]
140+
else
141+
_find_identical_components(_edgem)
142+
end
135143
end
136144

137145
# create vertex batches and initialize with index manager
@@ -198,6 +206,26 @@ function Network(g::AbstractGraph,
198206
return nw
199207
end
200208

209+
function _find_identical_components(models)
210+
# identical components are based on identical _component_hash
211+
# those can have different metadata but are considered identical when in comes to batching
212+
hashs = _component_hash.(models)
213+
find_identical(hashs)
214+
end
215+
# hash condition: components with same hash will end up in the same batch
216+
function _component_hash(c::ComponentModel)
217+
hash((
218+
typeof(c), # same type
219+
compf(c), # same f-function
220+
compg(c), # same g-function
221+
fftype(c), # same feedforward type
222+
dim(c), # same state dimension
223+
outdim(c), # same output dimension
224+
pdim(c), # same parameter dimension
225+
extdim(c), # same external input dimension
226+
))
227+
end
228+
201229
function Network(vertexfs, edgefs; kwargs...)
202230
vertexfs = vertexfs isa VertexModel ? [vertexfs] : vertexfs
203231
edgefs = edgefs isa EdgeModel ? [edgefs] : edgefs
@@ -324,28 +352,6 @@ function batch_by_idxs(v::AbstractVector, batches::Vector{Vector{Int}})
324352
[v[batch] for batch in batches]
325353
end
326354

327-
_find_identical(v::ComponentModel, indices) = [collect(indices)]
328-
function _find_identical(v::Vector{T}, indices) where {T<:ComponentModel}
329-
idxs_per_type = Vector{Int}[]
330-
unique_comp = T[]
331-
for i in eachindex(v)
332-
found = false
333-
for j in eachindex(unique_comp)
334-
if batchequal(v[i], unique_comp[j])
335-
found = true
336-
push!(idxs_per_type[j], indices[i])
337-
break
338-
end
339-
end
340-
if !found
341-
push!(unique_comp, v[i])
342-
push!(idxs_per_type, [indices[i]])
343-
end
344-
end
345-
@assert length(unique_comp) == length(idxs_per_type)
346-
return idxs_per_type
347-
end
348-
349355
function construct_mass_matrix(im; type=nothing)
350356
vertexd = filter(pairs(im.vertexm)) do (_, c)
351357
hasproperty(c,:mass_matrix) && c.mass_matrix != LinearAlgebra.I

src/utils.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,31 @@ abstract type SymbolicParameterIndex{C,S} <: SymbolicIndex{C,S} end
140140

141141
flatten_sym(v::NamedTuple) = reduce(vcat, values(v))
142142
flatten_sym(v::AbstractVector{Symbol}) = v
143+
144+
"""
145+
find_identical(v::Vector;; equality)
146+
147+
Find identical elements in a vector `v` using the `equality` function.
148+
Returns a vector of vectors where each vector contains the indices of identical elements.
149+
"""
150+
function find_identical(v::Vector{T}; equality=isequal) where {T}
151+
indices = eachindex(v)
152+
idxs_per_type = Vector{Int}[]
153+
unique_comp = T[]
154+
for i in eachindex(v)
155+
found = false
156+
for j in eachindex(unique_comp)
157+
if equality(v[i], unique_comp[j])
158+
found = true
159+
push!(idxs_per_type[j], indices[i])
160+
break
161+
end
162+
end
163+
if !found
164+
push!(unique_comp, v[i])
165+
push!(idxs_per_type, [indices[i]])
166+
end
167+
end
168+
@assert length(unique_comp) == length(idxs_per_type)
169+
return idxs_per_type
170+
end

test/utils_test.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,23 @@ using NetworkDynamics
3737
end
3838

3939
@testset "find_identical" begin
40-
using NetworkDynamics: _find_identical
40+
using NetworkDynamics: _find_identical_components
4141

4242
v1 = Lib.kuramoto_second()
43-
@test _find_identical(v1, 1:10) == [collect(1:10)]
43+
@test _find_identical_components([v1 for _ in 1:10]) == [collect(1:10)]
4444
v2 = Lib.diffusion_vertex()
4545
v3 = Lib.kuramoto_first()
4646

4747
# v2 and v3 are equal when it comes to the function!!
4848
vs = [v1,v2,v3,v2,v2,v1,v1,v3]
4949

50-
@test _find_identical(vs, eachindex(vs)) == [[1,6,7],[2,4,5],[3,8]]
50+
@test _find_identical_components(vs) == [[1,6,7],[2,4,5],[3,8]]
5151

5252
es = [Lib.diffusion_edge(),
5353
Lib.diffusion_edge_closure(),
5454
Lib.diffusion_edge_closure(),
5555
Lib.diffusion_edge_fid()]
56-
@test _find_identical(es, eachindex(es)) == [[1], [2], [3], [4]]
56+
@test _find_identical_components(es) == [[1], [2], [3], [4]]
5757
end
5858

5959
@testset "algin strings" begin

0 commit comments

Comments
 (0)