9999    #  return
100100
101101    Nchunks =  Threads. nthreads ()
102-     #  Nchunks = 4 
102+     #  Nchunks = 8 
103103    #  chunking is kinda expensive, so we cache it
104104    key =  hash ((Base. objectid (batches), filt, fg, Nchunks))
105105    chunks =  get! (ex. chunk_cache, key) do 
106106        _chunk_batches (batches, filt, fg, Nchunks)
107107    end 
108108
109-     _progress_in_batch  =  function (batch, ci, processed, N )
109+     _eval_batchportion  =  function   (batch, idxs )
110110        (du, u, o, p, t) =  duopt
111111        _type =  dispatchT (batch)
112-         while  ci ≤  length (batch) &&  processed <  N
113-             apply_comp! (_type, fg, batch, ci, du, u, o, inbufs, p, t)
114-             ci +=  1 
115-             processed +=  1 
112+         for  i in  idxs
113+             apply_comp! (_type, fg, batch, i, du, u, o, inbufs, p, t)
116114        end 
117-         processed, ci
118115    end 
119116
120117    Threads. @sync  for  chunk in  chunks
121-         chunk. N  ==   0  &&  continue 
118+         isempty ( chunk)  &&  continue 
122119        Threads. @spawn  begin 
123-             local  N =  chunk. N
124-             local  bi =  chunk. batch_i
125-             local  ci =  chunk. comp_i
126-             local  processed =  0 
127-             while  processed <  N
128-                 batch =  batches[bi]
129-                 filt (batch) ||  continue 
130-                 processed, ci =  @noinline  _progress_in_batch (batch, ci, processed, N)
131-                 bi +=  1 
132-                 ci =  1 
120+             for  (; bi, idxs) in  chunk
121+                 batch =  batches[bi] #  filtering don in chunks
122+                 @noinline  _eval_batchportion (batch, idxs)
133123            end 
134124        end 
135125    end 
@@ -141,51 +131,53 @@ function _chunk_batches(batches, filt, fg, workers)
141131        Ncomp +=  length (batch):: Int 
142132        total_eqs +=  length (batch):: Int  *  _N_eqs (fg, batch):: Int 
143133    end 
144-     chunks =  Vector {@NamedTuple{batch_i ::Int, comp_i::Int, N::Int }} (undef, workers)
134+     chunks =  Vector{Vector{ @NamedTuple {bi :: Int ,idxs :: UnitRange{Int64} } }}(undef, workers)
145135
146136    eqs_per_worker =  total_eqs /  workers
147137    bi =  1 
148138    ci =  1 
149139    assigned =  0 
150140    eqs_assigned =  0 
151141    for  w in  1 : workers
142+         chunk =  @NamedTuple {bi:: Int ,idxs:: UnitRange{Int64} }[]
152143        ci_start =  ci
153-         bi_start =  bi
154144        eqs_in_worker =  0 
155145        assigned_in_worker =  0 
156146        while  assigned <  Ncomp
157147            batch =  batches[bi]
158-             filt (batch) ||  continue 
159148
160-             Neqs =  _N_eqs (fg, batch)
161-             stop_collecting =  false 
162-             while  true 
163-                 if  ci >  length (batch)
164-                     break 
149+             if  filt (batch) # only process if the batch is not filtered out
150+                 Neqs =  _N_eqs (fg, batch)
151+                 stop_collecting =  false 
152+                 while  true 
153+                     if  ci >  length (batch)
154+                         break 
155+                     end 
156+ 
157+                     #  compare, whether adding the new component helps to come closer to eqs_per_worker
158+                     diff_now  =  abs (eqs_in_worker -  eqs_per_worker)
159+                     diff_next =  abs (eqs_in_worker +  Neqs -  eqs_per_worker)
160+                     stop_collecting =  assigned ==  Ncomp ||  diff_now ≤  diff_next
161+                     if  stop_collecting
162+                         break 
163+                     end 
164+ 
165+                     #  add component to worker
166+                     eqs_assigned +=  Neqs
167+                     eqs_in_worker +=  Neqs
168+                     assigned_in_worker +=  1 
169+                     assigned +=  1 
170+                     ci +=  1 
165171                end 
166- 
167-                 #  compare, whether adding the new component helps to come closer to eqs_per_worker
168-                 diff_now  =  abs (eqs_in_worker -  eqs_per_worker)
169-                 diff_next =  abs (eqs_in_worker +  Neqs -  eqs_per_worker)
170-                 stop_collecting =  assigned ==  Ncomp ||  diff_now ≤  diff_next
171-                 if  stop_collecting
172-                     break 
172+                 if  ci >  ci_start #  don't push empty chunks
173+                     push! (chunk, (; bi, idxs= ci_start: ci- 1 ))
173174                end 
174- 
175-                 #  add component to worker
176-                 eqs_assigned +=  Neqs
177-                 eqs_in_worker +=  Neqs
178-                 assigned_in_worker +=  1 
179-                 assigned +=  1 
180-                 ci +=  1 
175+                 stop_collecting &&  break 
181176            end 
182-             #  if the hard stop collection is reached, break, otherwise jump to next batch and continue
183-             stop_collecting &&  break 
184177
185178            bi +=  1 
186179            ci =  1 
187180        end 
188-         chunk =  (; batch_i= bi_start, comp_i= ci_start, N= assigned_in_worker)
189181        chunks[w] =  chunk
190182
191183        #  update eqs per worker estimate for the other workders
0 commit comments