@@ -163,42 +163,45 @@ function show(io::IO, ::MIME"text/plain", p::AbstractStatsProcedure{A,T}) where
163163end
164164
165165"""
166- SharedStatsStep{T<:StatsStep, I}
166+ SharedStatsStep
167167
168168A [`StatsStep`](@ref) that is possibly shared by
169169multiple instances of procedures that are subtypes of [`AbstractStatsProcedure`](@ref).
170170See also [`PooledStatsProcedure`](@ref).
171171
172- # Parameters
173- - `T<: StatsStep`: type of the only field `step`.
174- - `I `: indices of the procedures that share this step.
172+ # Fields
173+ - `step:: StatsStep`: the `step` that may be shared .
174+ - `ids::Vector{Int} `: indices of procedures that share ` step` .
175175"""
176- struct SharedStatsStep{T<: StatsStep , I}
177- step:: T
178- function SharedStatsStep (s:: StatsStep , pid)
179- pid = (unique! (sort! ([pid... ]))... ,)
180- return new {typeof(s), pid} (s)
176+ struct SharedStatsStep
177+ step:: StatsStep
178+ ids:: Vector{Int}
179+ function SharedStatsStep (s:: StatsStep , ids)
180+ ids = unique! (sort! ([ids... ]))
181+ return new (s, ids)
181182 end
182183end
183184
184- _sharedby (:: SharedStatsStep{T,I} ) where {T,I} = I
185+ _sharedby (s :: SharedStatsStep ) = s . ids
185186_f (s:: SharedStatsStep ) = _f (s. step)
186187groupargs (s:: SharedStatsStep , @nospecialize (ntargs:: NamedTuple )) = groupargs (s. step, ntargs)
187188combinedargs (s:: SharedStatsStep , v:: AbstractArray ) = combinedargs (s. step, v)
188189
190+ == (x:: SharedStatsStep , y:: SharedStatsStep ) =
191+ x. step == y. step && x. ids == y. ids
192+
189193show (io:: IO , s:: SharedStatsStep ) = print (io, s. step)
190194
191- function show (io:: IO , :: MIME"text/plain" , s:: SharedStatsStep{T,I} ) where {T,I}
192- nps = length (I )
195+ function show (io:: IO , :: MIME"text/plain" , s:: SharedStatsStep )
196+ nps = length (s . ids )
193197 print (io, s. step, " (StatsStep shared by " , nps, " procedure" )
194198 nps > 1 ? print (io, " s)" ) : print (io, " )" )
195199end
196200
197- const SharedStatsSteps = NTuple{N, SharedStatsStep} where N
198201const StatsProcedures = NTuple{N, AbstractStatsProcedure} where N
199202
200203"""
201- PooledStatsProcedure{P<:StatsProcedures, S<:SharedStatsSteps}
204+ PooledStatsProcedure
202205
203206A collection of procedures and shared steps.
204207
@@ -207,36 +210,37 @@ in a way that helps avoid repeating identical steps.
207210See also [`pool`](@ref).
208211
209212# Fields
210- - `procs::P `: a tuple of instances of subtypes of [`AbstractStatsProcedure`](@ref).
211- - `steps::S `: a tuple of [`SharedStatsStep`](@ref) for the procedures in `procs` .
213+ - `procs::StatsProcedures `: a tuple of instances of subtypes of [`AbstractStatsProcedure`](@ref).
214+ - `steps::Vector{SharedStatsStep} `: sorted [`SharedStatsStep`](@ref)s .
212215"""
213- struct PooledStatsProcedure{P <: StatsProcedures , S <: SharedStatsSteps }
214- procs:: P
215- steps:: S
216+ struct PooledStatsProcedure
217+ procs:: StatsProcedures
218+ steps:: Vector{SharedStatsStep}
216219end
217220
218- function _sort (psteps:: NTuple{N, Vector{SharedStatsStep}} ) where N
221+ function _sort (psteps:: Vector{Vector{SharedStatsStep}} )
222+ N = length (psteps)
219223 sorted = SharedStatsStep[]
220224 state = [length (s) for s in psteps]
221- pending = BitArray ( state.> 0 )
225+ pending = state .> 0
222226 while any (pending)
223227 pid = (1 : N)[pending]
224228 firsts = [psteps[i][end - state[i]+ 1 ] for i in pid]
225- for i in 1 : length (firsts)
226- nshared = length (_sharedby (firsts[i] ))
229+ for (i, fstep) in enumerate (firsts)
230+ nshared = length (_sharedby (fstep ))
227231 if nshared == 1
228- push! (sorted, firsts[i] )
232+ push! (sorted, fstep )
229233 state[pid[i]] -= 1
230234 else
231- shared = BitArray (s == firsts[i] for s in firsts )
235+ shared = firsts .== Ref (fstep )
232236 if sum (shared) == nshared
233- push! (sorted, firsts[i] )
237+ push! (sorted, fstep )
234238 state[pid[shared]] .- = 1
235239 break
236240 end
237241 end
238242 end
239- pending = BitArray ( state.> 0 )
243+ pending . = state .> 0
240244 end
241245 return sorted
242246end
@@ -260,7 +264,7 @@ function pool(ps::AbstractStatsProcedure...)
260264 steps = union (ps... )
261265 N = sum (length .(ps))
262266 if length (steps) < N
263- shared = (( Vector {SharedStatsStep} (undef, length (p)) for p in ps) . .. ,)
267+ shared = [ Vector {SharedStatsStep} (undef, length (p)) for p in ps]
264268 step_pos = Dict {StatsStep,Dict{Int64,Int64}} ()
265269 for (i, p) in enumerate (ps)
266270 for n in 1 : length (p)
@@ -298,33 +302,36 @@ function pool(ps::AbstractStatsProcedure...)
298302 end
299303 end
300304 end
301- shared = ( _sort (shared) ... , )
305+ shared = _sort (shared)
302306 else
303- shared = (( SharedStatsStep (s, i) for (i,p) in enumerate (ps) for s in p) . .. ,)
307+ shared = [ SharedStatsStep (s, i) for (i,p) in enumerate (ps) for s in p]
304308 end
305- return PooledStatsProcedure {typeof(ps), typeof(shared)} (ps, shared)
309+ return PooledStatsProcedure (ps, shared)
306310end
307311
308- length (:: PooledStatsProcedure{P,S} ) where {P,S} = length (S. parameters)
309- eltype (:: Type{<:PooledStatsProcedure} ) = SharedStatsStep
310- firstindex (:: PooledStatsProcedure{P,S} ) where {P,S} = firstindex (S. parameters)
311- lastindex (:: PooledStatsProcedure{P,S} ) where {P,S} = lastindex (S. parameters)
312+ length (p:: PooledStatsProcedure ) = length (p. steps)
313+ eltype (:: Type{PooledStatsProcedure} ) = SharedStatsStep
314+ firstindex (p:: PooledStatsProcedure ) = firstindex (p. steps)
315+ lastindex (p:: PooledStatsProcedure ) = lastindex (p. steps)
316+
317+ getindex (p:: PooledStatsProcedure , i) = getindex (p. steps, i)
312318
313- getindex (ps :: PooledStatsProcedure , i ) = getindex (ps . steps, i )
319+ iterate (p :: PooledStatsProcedure , state = 1 ) = iterate (p . steps, state )
314320
315- iterate (ps:: PooledStatsProcedure , state= 1 ) = iterate (ps. steps, state)
321+ == (x:: PooledStatsProcedure , y:: PooledStatsProcedure ) =
322+ x. procs == y. procs && x. steps == y. steps
316323
317- show (io:: IO , ps :: PooledStatsProcedure ) = print (io, typeof (ps ). name. name)
324+ show (io:: IO , p :: PooledStatsProcedure ) = print (io, typeof (p ). name. name)
318325
319- function show (io:: IO , :: MIME"text/plain" , ps :: PooledStatsProcedure{P,S} ) where {P,S}
320- nstep = length (S . parameters )
321- print (io, typeof (ps ). name. name, " with " , nstep, " step" )
326+ function show (io:: IO , :: MIME"text/plain" , p :: PooledStatsProcedure )
327+ nstep = length (p . steps )
328+ print (io, typeof (p ). name. name, " with " , nstep, " step" )
322329 nstep > 1 ? print (io, " s " ) : print (io, " " )
323- nps = length (P . parameters )
330+ nps = length (p . procs )
324331 print (io, " from " , nps, " procedure" )
325332 nps > 1 ? print (io, " s:" ) : print (io, " :" )
326- for p in P . parameters
327- print (io, " \n " , p . parameters[1 ])
333+ for p in p . procs
334+ print (io, " \n " , typeof (p) . parameters[1 ])
328335 end
329336end
330337
@@ -462,8 +469,8 @@ function proceed(sps::AbstractVector{<:StatsSpec};
462469 ret, share = ret
463470 else
464471 fname = typeof (_f (step)). name. mt. name
465- stepname = typeof (step) . parameters[ 1 ] . parameters[1 ]
466- error (" unexpected $(typeof (ret)) returned from $fname associated with StatsStep $stepname " )
472+ stepname = typeof (step. step) . parameters[1 ]
473+ error (" unexpected type $(typeof (ret)) of object returned from $fname associated with StatsStep $stepname " )
467474 end
468475 ntask += 1
469476 ntask_total += 1
0 commit comments