diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index f2f49c80c..b1d7dce21 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -34,12 +34,10 @@ function Adapt.adapt_structure(to, n::Network) caches = (;output = _adapt_diffcache(to, n.caches.output), aggregation = _adapt_diffcache(to, n.caches.aggregation), external = _adapt_diffcache(to, n.caches.external)) - exT = typeof(executionstyle(n)) - gT = typeof(n.im.g) + ex = executionstyle(n) extmap = adapt(to, n.extmap) - Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp),typeof(extmap)}( - vb, layer, n.im, caches, mm, gbp, extmap) + Network(vb, layer, n.im, caches, mm, gbp, extmap, ex) end Adapt.@adapt_structure NetworkLayer diff --git a/src/construction.jl b/src/construction.jl index f2e3e7f92..0f914f145 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -183,14 +183,14 @@ function Network(g::AbstractGraph, # create map for external inputs extmap = has_external_input(im) ? ExtMap(im) : nothing - nw = Network{typeof(execution),typeof(g),typeof(nl), typeof(vertexbatches), - typeof(mass_matrix),eltype(caches),typeof(gbufprovider),typeof(extmap)}( + nw = Network( vertexbatches, nl, im, caches, mass_matrix, gbufprovider, extmap, + execution, ) end diff --git a/src/coreloop.jl b/src/coreloop.jl index 1162c1520..c6ccb8d15 100644 --- a/src/coreloop.jl +++ b/src/coreloop.jl @@ -88,15 +88,112 @@ end end end -@inline function process_batches!(::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F} - unrolled_foreach(filt, batches) do batch - (du, u, o, p, t) = duopt - Threads.@threads for i in 1:length(batch) +@inline function process_batches!(ex::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F} + Nchunks = Threads.nthreads() + + # chunking is kinda expensive, so we cache it + key = hash((Base.objectid(batches), filt, fg, Nchunks)) + chunks = get!(ex.chunk_cache, key) do + _chunk_batches(batches, filt, fg, Nchunks) + end + + # each chunk consists of an array or tuple 
[(batch, idxs), ...] + _eval_chunk = function(chunk) + unrolled_foreach(chunk) do ch + (; batch, idxs) = ch + (du, u, o, p, t) = duopt _type = dispatchT(batch) - apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t) + for i in idxs + apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t) + end + end + end + Threads.@sync for chunk in chunks + Threads.@spawn begin + @noinline _eval_chunk(chunk) + end + end +end +function _chunk_batches(batches, filt, fg, workers) + Ncomp = 0 + total_eqs = 0 + unrolled_foreach(filt, batches) do batch + Ncomp += length(batch)::Int + total_eqs += length(batch)::Int * _N_eqs(fg, batch)::Int + end + chunks = Vector{Any}(undef, workers) + + eqs_per_worker = total_eqs / workers + # println("\nTotal eqs: $total_eqs in $Ncomp components, eqs per worker: $eqs_per_worker ($fg)") + bi = 1 + ci = 1 + assigned = 0 + eqs_assigned = 0 + for w in 1:workers + # println("Assign worker $w: goal: $eqs_per_worker") + chunk = Vector{Any}() + eqs_in_worker = 0 + assigned_in_worker = 0 + while assigned < Ncomp + batch = batches[bi] + + if filt(batch) #only process if the batch is not filtered out + ci_start = ci + Neqs = _N_eqs(fg, batch) + stop_collecting = false + while true + if ci > length(batch) + break + end + + # compare, whether adding the new component helps to come closer to eqs_per_worker + diff_now = abs(eqs_in_worker - eqs_per_worker) + diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker) + stop_collecting = assigned == Ncomp || diff_now < diff_next + if stop_collecting + break + end + + # add component to worker + # println(" - Assign component $ci ($Neqs eqs)") + eqs_assigned += Neqs + eqs_in_worker += Neqs + assigned_in_worker += 1 + assigned += 1 + ci += 1 + end + if ci > ci_start # don't push empty chunks + # println(" - Assign batch $(bi) -> $(ci_start:(ci-1)) $(length(ci_start:(ci-1))*Neqs) eqs)") + push!(chunk, (; batch, idxs=ci_start:(ci-1))) + else + # println(" - Skip empty batch $(bi) -> $(ci_start:(ci-1))") + end + 
stop_collecting && break + else + # println(" - Skip batch $(bi)") + end + + bi += 1 + ci = 1 + end + + # narrow down type / make tuple + chunks[w] = if length(chunk) < 10 + Tuple(chunk) + else + [c for c in chunk] # narrow down type end + + # update eqs per worker estimate for the other workers + eqs_per_worker = (total_eqs - eqs_assigned) / (workers - w) end + @assert assigned == Ncomp + return chunks end +_N_eqs(::Val{:f}, batch) = Int(dim(batch)) +_N_eqs(::Val{:g}, batch) = Int(outdim(batch)) +_N_eqs(::Val{:fg}, batch) = Int(dim(batch)) + Int(outdim(batch)) + @inline function process_batches!(::PolyesterExecution, fg, filt::F, batches, inbufs, duopt) where {F} unrolled_foreach(filt, batches) do batch diff --git a/src/executionstyles.jl b/src/executionstyles.jl index 3fc46a17c..523be9dcb 100644 --- a/src/executionstyles.jl +++ b/src/executionstyles.jl @@ -41,7 +41,9 @@ struct PolyesterExecution{buffered} <: ExecutionStyle{buffered} end Parallel execution using Julia threads. For `buffered` see [`ExecutionStyle`](@ref). 
""" -struct ThreadedExecution{buffered} <: ExecutionStyle{buffered} end +@kwdef struct ThreadedExecution{buffered} <: ExecutionStyle{buffered} + chunk_cache::Dict{UInt, Vector} = Dict{UInt, Vector}() +end usebuffer(::ExecutionStyle{buffered}) where {buffered} = buffered usebuffer(::Type{<:ExecutionStyle{buffered}}) where {buffered} = buffered diff --git a/src/network_structure.jl b/src/network_structure.jl index 5910f6f37..eabae40c0 100644 --- a/src/network_structure.jl +++ b/src/network_structure.jl @@ -83,8 +83,10 @@ struct Network{EX<:ExecutionStyle,G,NL,VTup,MM,CT,GBT,EM} gbufprovider::GBT "map to gather external inputs" extmap::EM + "execution style" + executionstyle::EX end -executionstyle(::Network{ex}) where {ex} = ex() +executionstyle(nw::Network) = nw.executionstyle nvbatches(::Network) = length(vertexbatches) """ @@ -164,6 +166,8 @@ end @inline compf(b::ComponentBatch) = b.compf @inline compg(b::ComponentBatch) = b.compg @inline fftype(b::ComponentBatch) = b.ff +@inline dim(b::ComponentBatch) = sum(b.statestride.strides) +@inline outdim(b::ComponentBatch) = sum(b.outbufstride.strides) @inline pdim(b::ComponentBatch) = b.pstride.strides @inline extdim(b::ComponentBatch) = b.extbufstride.strides