9999 # return
100100
101101 Nchunks = Threads. nthreads ()
102- # Nchunks = 4
102+ # Nchunks = 8
103103 # chunking is kinda expensive, so we cache it
104104 key = hash ((Base. objectid (batches), filt, fg, Nchunks))
105105 chunks = get! (ex. chunk_cache, key) do
106106 _chunk_batches (batches, filt, fg, Nchunks)
107107 end
108108
109- _progress_in_batch = function (batch, ci, processed, N )
109+ _eval_batchportion = function (batch, idxs )
110110 (du, u, o, p, t) = duopt
111111 _type = dispatchT (batch)
112- while ci ≤ length (batch) && processed < N
113- apply_comp! (_type, fg, batch, ci, du, u, o, inbufs, p, t)
114- ci += 1
115- processed += 1
112+ for i in idxs
113+ apply_comp! (_type, fg, batch, i, du, u, o, inbufs, p, t)
116114 end
117- processed, ci
118115 end
119116
120117 Threads. @sync for chunk in chunks
121- chunk. N == 0 && continue
118+ isempty ( chunk) && continue
122119 Threads. @spawn begin
123- local N = chunk. N
124- local bi = chunk. batch_i
125- local ci = chunk. comp_i
126- local processed = 0
127- while processed < N
128- batch = batches[bi]
129- filt (batch) || continue
130- processed, ci = @noinline _progress_in_batch (batch, ci, processed, N)
131- bi += 1
132- ci = 1
120+ for (; bi, idxs) in chunk
121+ batch = batches[bi] # filtering done in chunks
122+ @noinline _eval_batchportion (batch, idxs)
133123 end
134124 end
135125 end
@@ -141,51 +131,53 @@ function _chunk_batches(batches, filt, fg, workers)
141131 Ncomp += length (batch):: Int
142132 total_eqs += length (batch):: Int * _N_eqs (fg, batch):: Int
143133 end
144- chunks = Vector {@NamedTuple{batch_i ::Int, comp_i::Int, N::Int }} (undef, workers)
134+ chunks = Vector{Vector{ @NamedTuple {bi :: Int ,idxs :: UnitRange{Int64} } }}(undef, workers)
145135
146136 eqs_per_worker = total_eqs / workers
147137 bi = 1
148138 ci = 1
149139 assigned = 0
150140 eqs_assigned = 0
151141 for w in 1 : workers
142+ chunk = @NamedTuple {bi:: Int ,idxs:: UnitRange{Int64} }[]
152143 ci_start = ci
153- bi_start = bi
154144 eqs_in_worker = 0
155145 assigned_in_worker = 0
156146 while assigned < Ncomp
157147 batch = batches[bi]
158- filt (batch) || continue
159148
160- Neqs = _N_eqs (fg, batch)
161- stop_collecting = false
162- while true
163- if ci > length (batch)
164- break
149+ if filt (batch) # only process if the batch is not filtered out
150+ Neqs = _N_eqs (fg, batch)
151+ stop_collecting = false
152+ while true
153+ if ci > length (batch)
154+ break
155+ end
156+
157+ # compare, whether adding the new component helps to come closer to eqs_per_worker
158+ diff_now = abs (eqs_in_worker - eqs_per_worker)
159+ diff_next = abs (eqs_in_worker + Neqs - eqs_per_worker)
160+ stop_collecting = assigned == Ncomp || diff_now ≤ diff_next
161+ if stop_collecting
162+ break
163+ end
164+
165+ # add component to worker
166+ eqs_assigned += Neqs
167+ eqs_in_worker += Neqs
168+ assigned_in_worker += 1
169+ assigned += 1
170+ ci += 1
165171 end
166-
167- # compare, whether adding the new component helps to come closer to eqs_per_worker
168- diff_now = abs (eqs_in_worker - eqs_per_worker)
169- diff_next = abs (eqs_in_worker + Neqs - eqs_per_worker)
170- stop_collecting = assigned == Ncomp || diff_now ≤ diff_next
171- if stop_collecting
172- break
172+ if ci > ci_start # don't push empty chunks
173+ push! (chunk, (; bi, idxs= ci_start: ci- 1 ))
173174 end
174-
175- # add component to worker
176- eqs_assigned += Neqs
177- eqs_in_worker += Neqs
178- assigned_in_worker += 1
179- assigned += 1
180- ci += 1
175+ stop_collecting && break
181176 end
182- # if the hard stop collection is reached, break, otherwise jump to next batch and continue
183- stop_collecting && break
184177
185178 bi += 1
186179 ci = 1
187180 end
188- chunk = (; batch_i= bi_start, comp_i= ci_start, N= assigned_in_worker)
189181 chunks[w] = chunk
190182
191183 # update eqs per worker estimate for the other workers
0 commit comments