8989end 
9090
9191@inline  function  process_batches! (ex:: ThreadedExecution , fg, filt:: F , batches, inbufs, duopt) where  {F}
92-     #  unrolled_foreach(filt, batches) do batch
93-     #      (du, u, o, p, t) = duopt
94-     #      Threads.@threads for i in 1:length(batch)
95-     #          _type = dispatchT(batch)
96-     #          apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
97-     #      end
98-     #  end
99-     #  return
100- 
10192    Nchunks =  Threads. nthreads ()
102-      #  Nchunks = 8 
93+ 
10394    #  chunking is kinda expensive, so we cache it
10495    key =  hash ((Base. objectid (batches), filt, fg, Nchunks))
10596    chunks =  get! (ex. chunk_cache, key) do 
10697        _chunk_batches (batches, filt, fg, Nchunks)
10798    end 
10899
109-     _eval_batchportion =  function  (batch, idxs)
110-         (du, u, o, p, t) =  duopt
111-         _type =  dispatchT (batch)
112-         for  i in  idxs
113-             apply_comp! (_type, fg, batch, i, du, u, o, inbufs, p, t)
100+     #  each chunk consists of array or tuple [(batch, idxs), ...]
101+     _eval_chunk =  function (chunk)
102+         unrolled_foreach (chunk) do  ch
103+             (; batch, idxs) =  ch
104+             (du, u, o, p, t) =  duopt
105+             _type =  dispatchT (batch)
106+             for  i in  idxs
107+                 apply_comp! (_type, fg, batch, i, du, u, o, inbufs, p, t)
108+             end 
114109        end 
115110    end 
116- 
117111    Threads. @sync  for  chunk in  chunks
118-         isempty (chunk) &&  continue 
119112        Threads. @spawn  begin 
120-             for  (; bi, idxs) in  chunk
121-                 batch =  batches[bi] #  filtering don in chunks
122-                 @noinline  _eval_batchportion (batch, idxs)
123-             end 
113+             @noinline  _eval_chunk (chunk)
124114        end 
125115    end 
126116end 
@@ -131,15 +121,15 @@ function _chunk_batches(batches, filt, fg, workers)
131121        Ncomp +=  length (batch):: Int 
132122        total_eqs +=  length (batch):: Int  *  _N_eqs (fg, batch):: Int 
133123    end 
134-     chunks =  Vector{Vector{ @NamedTuple {bi :: Int ,idxs :: UnitRange{Int64} }} }(undef, workers)
124+     chunks =  Vector {Any } (undef, workers)
135125
136126    eqs_per_worker =  total_eqs /  workers
137127    bi =  1 
138128    ci =  1 
139129    assigned =  0 
140130    eqs_assigned =  0 
141131    for  w in  1 : workers
142-         chunk =  @NamedTuple {bi :: Int ,idxs :: UnitRange{Int64} }[] 
132+         chunk =  Vector {Any} () 
143133        ci_start =  ci
144134        eqs_in_worker =  0 
145135        assigned_in_worker =  0 
@@ -170,15 +160,21 @@ function _chunk_batches(batches, filt, fg, workers)
170160                    ci +=  1 
171161                end 
172162                if  ci >  ci_start #  don't push empty chunks
173-                     push! (chunk, (; bi , idxs= ci_start: ci- 1 ))
163+                     push! (chunk, (; batch , idxs= ci_start: ci- 1 ))
174164                end 
175165                stop_collecting &&  break 
176166            end 
177167
178168            bi +=  1 
179169            ci =  1 
180170        end 
181-         chunks[w] =  chunk
171+ 
172+         #  narrow down type / make tuple
173+         chunks[w] =  if  length (chunk) <  10 
174+             Tuple (chunk)
175+         else 
176+             [c for  c in  chunk] #  narrow down type
177+         end 
182178
183179        #  update eqs per worker estimate for the other workers
184180        eqs_per_worker =  (total_eqs -  eqs_assigned) /  (workers -  w)
0 commit comments