@@ -26,10 +26,9 @@ a message with the name of the `StatsStep` is printed to `stdout`.
2626"""
2727struct StatsStep{Alias,F<: Function ,SpecNames,TraceNames} end
2828
29- _alias (step:: StatsStep{A} ) where A = A
30- _f (step:: StatsStep{A,F} ) where {A,F} = F. instance
31- _specnames (step:: StatsStep{A,F,S} ) where {A,F,S} = S
32- _tracenames (step:: StatsStep{A,F,S,T} ) where {A,F,S,T} = T
29+ _f (:: StatsStep{A,F} ) where {A,F} = F. instance
30+ _specnames (:: StatsStep{A,F,S} ) where {A,F,S} = S
31+ _tracenames (:: StatsStep{A,F,S,T} ) where {A,F,S,T} = T
3332
3433function (step:: StatsStep{A,F,S,T} )(ntargs:: NamedTuple ; verbose:: Bool = false ) where {A,F,S,T}
3534 verbose || (haskey (ntargs, :verbose ) && ntargs. verbose) &&
@@ -45,10 +44,10 @@ function (step::StatsStep{A,F,S,T})(ntargs::NamedTuple; verbose::Bool=false) whe
4544 end
4645end
4746
48- show (io:: IO , s :: StatsStep{A} ) where A = print (io, A)
47+ show (io:: IO , :: StatsStep{A} ) where A = print (io, A)
4948
5049function show (io:: IO , :: MIME"text/plain" , s:: StatsStep{A,F,S,T} ) where {A,F,S,T}
51- print (io, A, " (" , typeof (s). name, " that calls " )
50+ print (io, A, " (" , typeof (s). name. name , " that calls " )
5251 fmod = F. name. mt. module
5352 fmod == Main ? print (io, F. name. mt. name) : print (io, fmod, " ." , F. name. mt. name)
5453 println (io, " ):" )
@@ -71,26 +70,26 @@ all subtypes of `AbstractStatsProcedure`.
7170"""
7271abstract type AbstractStatsProcedure{A,T<: NTuple{N,StatsStep} where N} end
7372
74- length (p :: AbstractStatsProcedure{A,T} ) where {A,T} = length (T. parameters)
73+ length (:: AbstractStatsProcedure{A,T} ) where {A,T} = length (T. parameters)
7574eltype (:: Type{<:AbstractStatsProcedure} ) = StatsStep
76- firstindex (p :: AbstractStatsProcedure{A,T} ) where {A,T} = firstindex (T. parameters)
77- lastindex (p :: AbstractStatsProcedure{A,T} ) where {A,T} = lastindex (T. parameters)
75+ firstindex (:: AbstractStatsProcedure{A,T} ) where {A,T} = firstindex (T. parameters)
76+ lastindex (:: AbstractStatsProcedure{A,T} ) where {A,T} = lastindex (T. parameters)
7877
7978function getindex (p:: AbstractStatsProcedure{A,T} , i) where {A,T}
8079 fs = T. parameters[i]
8180 return fs isa Type && fs <: StatsStep ? fs. instance : [f. instance for f in fs]
8281end
8382
84- getindex (p :: AbstractStatsProcedure{A,T} , i:: Int ) where {A,T} = T. parameters[i]. instance
83+ getindex (:: AbstractStatsProcedure{A,T} , i:: Int ) where {A,T} = T. parameters[i]. instance
8584
8685iterate (p:: AbstractStatsProcedure , state= 1 ) =
8786 state > length (p) ? nothing : (p[state], state+ 1 )
8887
89- show (io:: IO , p :: AbstractStatsProcedure{A} ) where A = print (io, A)
88+ show (io:: IO , :: AbstractStatsProcedure{A} ) where A = print (io, A)
9089
9190function show (io:: IO , :: MIME"text/plain" , p:: AbstractStatsProcedure{A,T} ) where {A,T}
9291 nstep = length (p)
93- print (io, A, " (" , typeof (p). name, " with " , nstep, " step" )
92+ print (io, A, " (" , typeof (p). name. name , " with " , nstep, " step" )
9493 if nstep > 0
9594 nstep > 1 ? print (io, " s):\n " ) : print (io, " ):\n " )
9695 for (i, step) in enumerate (p)
@@ -103,68 +102,91 @@ function show(io::IO, ::MIME"text/plain", p::AbstractStatsProcedure{A,T}) where
103102end
104103
105104"""
106- SharedStatsStep{T<:StatsStep,PID }
105+ SharedStatsStep{T<:StatsStep,I }
107106
108107A [`StatsStep`](@ref) that is possibly shared by
109108multiple instances of procedures that are subtypes of [`AbstractStatsProcedure`](@ref).
110109See also [`PooledStatsProcedure`](@ref).
111110
112111# Parameters
113112- `T<:StatsStep`: type of the only field `step`.
114- - `PID `: indices of the procedures that share this step.
113+ - `I `: indices of the procedures that share this step.
115114"""
116- struct SharedStatsStep{T<: StatsStep ,PID }
115+ struct SharedStatsStep{T<: StatsStep ,I }
117116 step:: T
117+ function SharedStatsStep (s:: StatsStep , pid)
118+ pid = (unique! (sort! ([pid... ]))... ,)
119+ return new {typeof(s), pid} (s)
120+ end
118121end
119122
120- == (x:: SharedStatsStep{T,PID1} , y:: SharedStatsStep{T,PID2} ) where {T,PID1,PID2} =
121- x. step == y. step && Set (PID1) == Set (PID2)
122-
123- _share (s:: T , pid) where {T<: StatsStep } = SharedStatsStep {T,(Int.(pid)...,)} (s)
124- _sharedby (s:: SharedStatsStep{T,PID} ) where {T,PID} = PID
123+ _sharedby (:: SharedStatsStep{T,I} ) where {T,I} = I
125124_f (s:: SharedStatsStep ) = _f (s. step)
126125_specnames (s:: SharedStatsStep ) = _specnames (s. step)
127126_tracenames (s:: SharedStatsStep ) = _tracenames (s. step)
128127
129128show (io:: IO , s:: SharedStatsStep ) = print (io, s. step)
130129
131- function show (io:: IO , :: MIME"text/plain" , s:: SharedStatsStep{T,PID } ) where {T,PID }
132- nps = length (PID )
130+ function show (io:: IO , :: MIME"text/plain" , s:: SharedStatsStep{T,I } ) where {T,I }
131+ nps = length (I )
133132 print (io, s. step, " (StatsStep shared by " , nps, " procedure" )
134133 nps > 1 ? print (io, " s)" ) : print (io, " )" )
135134end
136135
137- const SharedStatsSteps = NTuple{N, Vector{ SharedStatsStep} } where N
136+ const SharedStatsSteps = NTuple{N, SharedStatsStep} where N
138137const StatsProcedures = NTuple{N, AbstractStatsProcedure} where N
139138
140139"""
141- PooledStatsProcedure{P<:StatsProcedures,S<:SharedStatsSteps,N }
140+ PooledStatsProcedure{P<:StatsProcedures,S<:SharedStatsSteps}
142141
143142A collection of procedures and shared steps.
144143
145- An instance of `PooledStatsProcedure` is iterable among the shared steps
144+ An instance of `PooledStatsProcedure` is indexed and iterable among the shared steps
146145in a way that helps avoid repeating identical steps.
147146See also [`pool`](@ref).
148147
149148# Fields
150149- `procs::P`: a tuple of instances of subtypes of [`AbstractStatsProcedure`](@ref).
151- - `steps::S`: a tuple of vectors [`SharedStatsStep`](@ref) for each procedure in `procs`.
150+ - `steps::S`: a tuple of [`SharedStatsStep`](@ref) for the procedures in `procs`.
152151"""
153- struct PooledStatsProcedure{P<: StatsProcedures ,S<: SharedStatsSteps ,N }
152+ struct PooledStatsProcedure{P<: StatsProcedures ,S<: SharedStatsSteps }
154153 procs:: P
155154 steps:: S
156155end
157156
158- == (x:: PooledStatsProcedure{P,S,N} , y:: PooledStatsProcedure{P,S,N} ) where {P,S,N} =
159- x. procs == y. procs && x. steps == y. steps
157+ function _sort (psteps:: NTuple{N, Vector{SharedStatsStep}} ) where N
158+ sorted = SharedStatsStep[]
159+ state = [length (s) for s in psteps]
160+ pending = BitArray (state.> 0 )
161+ while any (pending)
162+ pid = (1 : N)[pending]
163+ firsts = [psteps[i][end - state[i]+ 1 ] for i in pid]
164+ for i in 1 : length (firsts)
165+ nshared = length (_sharedby (firsts[i]))
166+ if nshared == 1
167+ push! (sorted, firsts[i])
168+ state[pid[i]] -= 1
169+ else
170+ shared = BitArray (s== firsts[i] for s in firsts)
171+ if sum (shared) == nshared
172+ push! (sorted, firsts[i])
173+ state[pid[shared]] .- = 1
174+ break
175+ end
176+ end
177+ end
178+ pending = BitArray (state.> 0 )
179+ end
180+ return sorted
181+ end
160182
161183"""
162184 pool(ps::AbstractStatsProcedure...)
163185
164186Construct a [`PooledStatsProcedure`](@ref) by determining
165187how each [`StatsStep`](@ref) is shared among several procedures in `ps`.
166188
167- It might not be safe to share the same [`StatsStep`](@ref) in different procedures
189+ It is unsafe to share the same [`StatsStep`](@ref) in different procedures
168190due to the relative position of this step to the other common steps
169191among these procedures.
170192The fallback method implemented for a collection of [`AbstractStatsProcedure`](@ref)
@@ -174,79 +196,69 @@ are not compatible between a pair of procedures.
174196function pool (ps:: AbstractStatsProcedure... )
175197 ps = (ps... ,)
176198 nps = length (ps)
177- shared = ((Vector {SharedStatsStep} (undef, length (p)) for p in ps). .. ,)
178- for (pid, p) in enumerate (ps)
179- shared[pid] .= _share .(collect (p), pid)
180- end
181- steps = union (collect (p) for p in ps)
199+ steps = union (ps... )
182200 N = sum (length .(ps))
183201 if length (steps) < N
184- step_loc = Dict {StatsStep,Dict{Int64,Int64}} ()
202+ shared = ((Vector {SharedStatsStep} (undef, length (p)) for p in ps). .. ,)
203+ step_pos = Dict {StatsStep,Dict{Int64,Int64}} ()
185204 for (i, p) in enumerate (ps)
186205 for n in 1 : length (p)
187- if haskey (step_loc , p[n])
188- step_loc [p[n]][i] = n
206+ if haskey (step_pos , p[n])
207+ step_pos [p[n]][i] = n
189208 else
190- step_loc [p[n]] = Dict (i=> n)
209+ step_pos [p[n]] = Dict (i=> n)
191210 end
192211 end
193212 end
194- for (step, loc) in step_loc
195- if length (loc) == 1
196- continue
213+ for (step, pos) in step_pos
214+ if length (pos) == 1
215+ kv = collect (pos)[1 ]
216+ shared[kv[1 ]][kv[2 ]] = SharedStatsStep (step, kv[1 ])
197217 else
198- shared_pid = collect (keys (loc ))
218+ shared_pid = collect (keys (pos ))
199219 for c in combinations (shared_pid, 2 )
200220 csteps = intersect (ps[c[1 ]], ps[c[2 ]])
201- rank1 = findfirst (x-> x== step, sort! ([step_loc[s][c[1 ]] for s in csteps]))
202- rank2 = findfirst (x-> x== step, sort! ([step_loc[s][c[2 ]] for s in csteps]))
221+ pos1 = sort (csteps, by= x-> step_pos[x][c[1 ]])
222+ pos2 = sort (csteps, by= x-> step_pos[x][c[2 ]])
223+ rank1 = findfirst (x-> x== step, pos1)
224+ rank2 = findfirst (x-> x== step, pos2)
203225 if rank1 != rank2
204- setdiff (shared_pid, c)
226+ setdiff! (shared_pid, c)
227+ shared[c[1 ]][pos[c[1 ]]] = SharedStatsStep (step, c[1 ])
228+ shared[c[2 ]][pos[c[2 ]]] = SharedStatsStep (step, c[2 ])
205229 length (shared_pid) <= 1 && break
206230 end
207231 end
208- if length (shared_pid) >= 2
232+ if length (shared_pid) > 0
209233 N = N - length (shared_pid) + 1
210- for s in shared_pid
211- shared[s][loc[s ]] = _share (step, shared_pid)
234+ for p in shared_pid
235+ shared[p][pos[p ]] = SharedStatsStep (step, shared_pid)
212236 end
213237 end
214238 end
215239 end
240+ shared = (_sort (shared)... ,)
241+ else
242+ shared = ((SharedStatsStep (s, i) for (i,p) in enumerate (ps) for s in p). .. ,)
216243 end
217- return PooledStatsProcedure {typeof(ps), typeof(shared), N } (ps, shared)
244+ return PooledStatsProcedure {typeof(ps), typeof(shared)} (ps, shared)
218245end
219246
220- length (ps :: PooledStatsProcedure{P,S,N } ) where {P,S,N } = N
247+ length (:: PooledStatsProcedure{P,S} ) where {P,S} = length (S . parameters)
221248eltype (:: Type{<:PooledStatsProcedure} ) = SharedStatsStep
249+ firstindex (:: PooledStatsProcedure{P,S} ) where {P,S} = firstindex (S. parameters)
250+ lastindex (:: PooledStatsProcedure{P,S} ) where {P,S} = lastindex (S. parameters)
222251
223- function iterate (ps:: PooledStatsProcedure , state= deepcopy (ps. steps))
224- state = state[BitArray (length .(state).> 0 )]
225- length (state) > 0 || return nothing
226- firsts = first .(state)
227- for i in 1 : length (firsts)
228- nshared = length (_sharedby (firsts[i]))
229- if nshared == 1
230- deleteat! (state[i],1 )
231- return (firsts[i], state)
232- else
233- shared = BitArray (s== firsts[i] for s in firsts)
234- if sum (shared) == nshared
235- for p in state[shared]
236- deleteat! (p, 1 )
237- end
238- return (firsts[i], state)
239- end
240- end
241- end
242- error (" bad construction of $(typeof (ps)) " )
243- end
252+ getindex (ps:: PooledStatsProcedure , i) = getindex (ps. steps, i)
244253
245- show (io :: IO , ps:: PooledStatsProcedure ) = print (io, typeof (ps) . name )
254+ iterate ( ps:: PooledStatsProcedure , state = 1 ) = iterate (ps . steps, state )
246255
247- function show (io:: IO , :: MIME"text/plain" , ps:: PooledStatsProcedure{P,S,N} ) where {P,S,N}
248- print (io, typeof (ps). name, " with " , N, " step" )
249- N > 1 ? print (io, " s " ) : print (io, " " )
256+ show (io:: IO , ps:: PooledStatsProcedure ) = print (io, typeof (ps). name. name)
257+
258+ function show (io:: IO , :: MIME"text/plain" , ps:: PooledStatsProcedure{P,S} ) where {P,S}
259+ nstep = length (S. parameters)
260+ print (io, typeof (ps). name. name, " with " , nstep, " step" )
261+ nstep > 1 ? print (io, " s " ) : print (io, " " )
250262 nps = length (P. parameters)
251263 print (io, " from " , nps, " procedure" )
252264 nps > 1 ? print (io, " s:" ) : print (io, " :" )
@@ -308,8 +320,6 @@ while ignoring the orders.
308320≊ (x:: StatsSpec{A1,T} , y:: StatsSpec{A2,T} ) where {A1,A2,T} =
309321 x. args ≊ y. args
310322
311- _procedure (sp:: StatsSpec{A,T} ) where {A,T} = T
312-
313323function (sp:: StatsSpec{A,T} )(;
314324 verbose:: Bool = false , keep= nothing , keepall:: Bool = false ) where {A,T}
315325 args = verbose ? merge (sp. args, (verbose= true ,)) : sp. args
@@ -333,12 +343,12 @@ function (sp::StatsSpec{A,T})(;
333343 end
334344end
335345
336- show (io:: IO , sp :: StatsSpec{A,T } ) where {A,T } = print (io, A== Symbol (" " ) ? " unnamed" : A)
346+ show (io:: IO , :: StatsSpec{A} ) where {A} = print (io, A== Symbol (" " ) ? " unnamed" : A)
337347
338- _show_args (io :: IO , sp :: StatsSpec ) = nothing
348+ _show_args (:: IO , :: StatsSpec ) = nothing
339349
340350function show (io:: IO , :: MIME"text/plain" , sp:: StatsSpec{A,T} ) where {A,T}
341- print (io, A== Symbol (" " ) ? " unnamed" : A, " (" , typeof (sp). name,
351+ print (io, A== Symbol (" " ) ? " unnamed" : A, " (" , typeof (sp). name. name ,
342352 " for " , T. parameters[1 ], " )" )
343353 _show_args (io, sp)
344354end
0 commit comments