Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 760cc2e

Browse files
committed
Improve PooledStatsProcedure and pool
1 parent b67a83a commit 760cc2e

File tree

2 files changed

+142
-133
lines changed

2 files changed

+142
-133
lines changed

src/StatsProcedures.jl

Lines changed: 91 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ a message with the name of the `StatsStep` is printed to `stdout`.
2626
"""
2727
struct StatsStep{Alias,F<:Function,SpecNames,TraceNames} end
2828

29-
_alias(step::StatsStep{A}) where A = A
30-
_f(step::StatsStep{A,F}) where {A,F} = F.instance
31-
_specnames(step::StatsStep{A,F,S}) where {A,F,S} = S
32-
_tracenames(step::StatsStep{A,F,S,T}) where {A,F,S,T} = T
29+
_f(::StatsStep{A,F}) where {A,F} = F.instance
30+
_specnames(::StatsStep{A,F,S}) where {A,F,S} = S
31+
_tracenames(::StatsStep{A,F,S,T}) where {A,F,S,T} = T
3332

3433
function (step::StatsStep{A,F,S,T})(ntargs::NamedTuple; verbose::Bool=false) where {A,F,S,T}
3534
verbose || (haskey(ntargs, :verbose) && ntargs.verbose) &&
@@ -45,10 +44,10 @@ function (step::StatsStep{A,F,S,T})(ntargs::NamedTuple; verbose::Bool=false) whe
4544
end
4645
end
4746

48-
show(io::IO, s::StatsStep{A}) where A = print(io, A)
47+
show(io::IO, ::StatsStep{A}) where A = print(io, A)
4948

5049
function show(io::IO, ::MIME"text/plain", s::StatsStep{A,F,S,T}) where {A,F,S,T}
51-
print(io, A, " (", typeof(s).name, " that calls ")
50+
print(io, A, " (", typeof(s).name.name, " that calls ")
5251
fmod = F.name.mt.module
5352
fmod == Main ? print(io, F.name.mt.name) : print(io, fmod, ".", F.name.mt.name)
5453
println(io, "):")
@@ -71,26 +70,26 @@ all subtypes of `AbstractStatsProcedure`.
7170
"""
7271
abstract type AbstractStatsProcedure{A,T<:NTuple{N,StatsStep} where N} end
7372

74-
length(p::AbstractStatsProcedure{A,T}) where {A,T} = length(T.parameters)
73+
length(::AbstractStatsProcedure{A,T}) where {A,T} = length(T.parameters)
7574
eltype(::Type{<:AbstractStatsProcedure}) = StatsStep
76-
firstindex(p::AbstractStatsProcedure{A,T}) where {A,T} = firstindex(T.parameters)
77-
lastindex(p::AbstractStatsProcedure{A,T}) where {A,T} = lastindex(T.parameters)
75+
firstindex(::AbstractStatsProcedure{A,T}) where {A,T} = firstindex(T.parameters)
76+
lastindex(::AbstractStatsProcedure{A,T}) where {A,T} = lastindex(T.parameters)
7877

7978
function getindex(p::AbstractStatsProcedure{A,T}, i) where {A,T}
8079
fs = T.parameters[i]
8180
return fs isa Type && fs <: StatsStep ? fs.instance : [f.instance for f in fs]
8281
end
8382

84-
getindex(p::AbstractStatsProcedure{A,T}, i::Int) where {A,T} = T.parameters[i].instance
83+
getindex(::AbstractStatsProcedure{A,T}, i::Int) where {A,T} = T.parameters[i].instance
8584

8685
iterate(p::AbstractStatsProcedure, state=1) =
8786
state > length(p) ? nothing : (p[state], state+1)
8887

89-
show(io::IO, p::AbstractStatsProcedure{A}) where A = print(io, A)
88+
show(io::IO, ::AbstractStatsProcedure{A}) where A = print(io, A)
9089

9190
function show(io::IO, ::MIME"text/plain", p::AbstractStatsProcedure{A,T}) where {A,T}
9291
nstep = length(p)
93-
print(io, A, " (", typeof(p).name, " with ", nstep, " step")
92+
print(io, A, " (", typeof(p).name.name, " with ", nstep, " step")
9493
if nstep > 0
9594
nstep > 1 ? print(io, "s):\n ") : print(io, "):\n ")
9695
for (i, step) in enumerate(p)
@@ -103,68 +102,91 @@ function show(io::IO, ::MIME"text/plain", p::AbstractStatsProcedure{A,T}) where
103102
end
104103

105104
"""
106-
SharedStatsStep{T<:StatsStep,PID}
105+
SharedStatsStep{T<:StatsStep,I}
107106
108107
A [`StatsStep`](@ref) that is possibly shared by
109108
multiple instances of procedures that are subtypes of [`AbstractStatsProcedure`](@ref).
110109
See also [`PooledStatsProcedure`](@ref).
111110
112111
# Parameters
113112
- `T<:StatsStep`: type of the only field `step`.
114-
- `PID`: indices of the procedures that share this step.
113+
- `I`: indices of the procedures that share this step.
115114
"""
116-
struct SharedStatsStep{T<:StatsStep,PID}
115+
struct SharedStatsStep{T<:StatsStep,I}
117116
step::T
117+
function SharedStatsStep(s::StatsStep, pid)
118+
pid = (unique!(sort!([pid...]))...,)
119+
return new{typeof(s), pid}(s)
120+
end
118121
end
119122

120-
==(x::SharedStatsStep{T,PID1}, y::SharedStatsStep{T,PID2}) where {T,PID1,PID2} =
121-
x.step == y.step && Set(PID1) == Set(PID2)
122-
123-
_share(s::T, pid) where {T<:StatsStep} = SharedStatsStep{T,(Int.(pid)...,)}(s)
124-
_sharedby(s::SharedStatsStep{T,PID}) where {T,PID} = PID
123+
_sharedby(::SharedStatsStep{T,I}) where {T,I} = I
125124
_f(s::SharedStatsStep) = _f(s.step)
126125
_specnames(s::SharedStatsStep) = _specnames(s.step)
127126
_tracenames(s::SharedStatsStep) = _tracenames(s.step)
128127

129128
show(io::IO, s::SharedStatsStep) = print(io, s.step)
130129

131-
function show(io::IO, ::MIME"text/plain", s::SharedStatsStep{T,PID}) where {T,PID}
132-
nps = length(PID)
130+
function show(io::IO, ::MIME"text/plain", s::SharedStatsStep{T,I}) where {T,I}
131+
nps = length(I)
133132
print(io, s.step, " (StatsStep shared by ", nps, " procedure")
134133
nps > 1 ? print(io, "s)") : print(io, ")")
135134
end
136135

137-
const SharedStatsSteps = NTuple{N, Vector{SharedStatsStep}} where N
136+
const SharedStatsSteps = NTuple{N, SharedStatsStep} where N
138137
const StatsProcedures = NTuple{N, AbstractStatsProcedure} where N
139138

140139
"""
141-
PooledStatsProcedure{P<:StatsProcedures,S<:SharedStatsSteps,N}
140+
PooledStatsProcedure{P<:StatsProcedures,S<:SharedStatsSteps}
142141
143142
A collection of procedures and shared steps.
144143
145-
An instance of `PooledStatsProcedure` is iterable among the shared steps
144+
An instance of `PooledStatsProcedure` is indexed and iterable among the shared steps
146145
in a way that helps avoid repeating identical steps.
147146
See also [`pool`](@ref).
148147
149148
# Fields
150149
- `procs::P`: a tuple of instances of subtypes of [`AbstractStatsProcedure`](@ref).
151-
- `steps::S`: a tuple of vectors [`SharedStatsStep`](@ref) for each procedure in `procs`.
150+
- `steps::S`: a tuple of [`SharedStatsStep`](@ref) for the procedures in `procs`.
152151
"""
153-
struct PooledStatsProcedure{P<:StatsProcedures,S<:SharedStatsSteps,N}
152+
struct PooledStatsProcedure{P<:StatsProcedures,S<:SharedStatsSteps}
154153
procs::P
155154
steps::S
156155
end
157156

158-
==(x::PooledStatsProcedure{P,S,N}, y::PooledStatsProcedure{P,S,N}) where {P,S,N} =
159-
x.procs == y.procs && x.steps == y.steps
157+
function _sort(psteps::NTuple{N, Vector{SharedStatsStep}}) where N
158+
sorted = SharedStatsStep[]
159+
state = [length(s) for s in psteps]
160+
pending = BitArray(state.>0)
161+
while any(pending)
162+
pid = (1:N)[pending]
163+
firsts = [psteps[i][end-state[i]+1] for i in pid]
164+
for i in 1:length(firsts)
165+
nshared = length(_sharedby(firsts[i]))
166+
if nshared == 1
167+
push!(sorted, firsts[i])
168+
state[pid[i]] -= 1
169+
else
170+
shared = BitArray(s==firsts[i] for s in firsts)
171+
if sum(shared) == nshared
172+
push!(sorted, firsts[i])
173+
state[pid[shared]] .-= 1
174+
break
175+
end
176+
end
177+
end
178+
pending = BitArray(state.>0)
179+
end
180+
return sorted
181+
end
160182

161183
"""
162184
pool(ps::AbstractStatsProcedure...)
163185
164186
Construct a [`PooledStatsProcedure`](@ref) by determining
165187
how each [`StatsStep`](@ref) is shared among several procedures in `ps`.
166188
167-
It might not be safe to share the same [`StatsStep`](@ref) in different procedures
189+
It is unsafe to share the same [`StatsStep`](@ref) in different procedures
168190
due to the relative position of this step to the other common steps
169191
among these procedures.
170192
The fallback method implemented for a collection of [`AbstractStatsProcedure`](@ref)
@@ -174,79 +196,69 @@ are not compatible between a pair of procedures.
174196
function pool(ps::AbstractStatsProcedure...)
175197
ps = (ps...,)
176198
nps = length(ps)
177-
shared = ((Vector{SharedStatsStep}(undef, length(p)) for p in ps)...,)
178-
for (pid, p) in enumerate(ps)
179-
shared[pid] .= _share.(collect(p), pid)
180-
end
181-
steps = union(collect(p) for p in ps)
199+
steps = union(ps...)
182200
N = sum(length.(ps))
183201
if length(steps) < N
184-
step_loc = Dict{StatsStep,Dict{Int64,Int64}}()
202+
shared = ((Vector{SharedStatsStep}(undef, length(p)) for p in ps)...,)
203+
step_pos = Dict{StatsStep,Dict{Int64,Int64}}()
185204
for (i, p) in enumerate(ps)
186205
for n in 1:length(p)
187-
if haskey(step_loc, p[n])
188-
step_loc[p[n]][i] = n
206+
if haskey(step_pos, p[n])
207+
step_pos[p[n]][i] = n
189208
else
190-
step_loc[p[n]] = Dict(i=>n)
209+
step_pos[p[n]] = Dict(i=>n)
191210
end
192211
end
193212
end
194-
for (step, loc) in step_loc
195-
if length(loc) == 1
196-
continue
213+
for (step, pos) in step_pos
214+
if length(pos) == 1
215+
kv = collect(pos)[1]
216+
shared[kv[1]][kv[2]] = SharedStatsStep(step, kv[1])
197217
else
198-
shared_pid = collect(keys(loc))
218+
shared_pid = collect(keys(pos))
199219
for c in combinations(shared_pid, 2)
200220
csteps = intersect(ps[c[1]], ps[c[2]])
201-
rank1 = findfirst(x->x==step, sort!([step_loc[s][c[1]] for s in csteps]))
202-
rank2 = findfirst(x->x==step, sort!([step_loc[s][c[2]] for s in csteps]))
221+
pos1 = sort(csteps, by=x->step_pos[x][c[1]])
222+
pos2 = sort(csteps, by=x->step_pos[x][c[2]])
223+
rank1 = findfirst(x->x==step, pos1)
224+
rank2 = findfirst(x->x==step, pos2)
203225
if rank1 != rank2
204-
setdiff(shared_pid, c)
226+
setdiff!(shared_pid, c)
227+
shared[c[1]][pos[c[1]]] = SharedStatsStep(step, c[1])
228+
shared[c[2]][pos[c[2]]] = SharedStatsStep(step, c[2])
205229
length(shared_pid) <= 1 && break
206230
end
207231
end
208-
if length(shared_pid) >= 2
232+
if length(shared_pid) > 0
209233
N = N - length(shared_pid) + 1
210-
for s in shared_pid
211-
shared[s][loc[s]] = _share(step, shared_pid)
234+
for p in shared_pid
235+
shared[p][pos[p]] = SharedStatsStep(step, shared_pid)
212236
end
213237
end
214238
end
215239
end
240+
shared = (_sort(shared)...,)
241+
else
242+
shared = ((SharedStatsStep(s, i) for (i,p) in enumerate(ps) for s in p)...,)
216243
end
217-
return PooledStatsProcedure{typeof(ps), typeof(shared), N}(ps, shared)
244+
return PooledStatsProcedure{typeof(ps), typeof(shared)}(ps, shared)
218245
end
219246

220-
length(ps::PooledStatsProcedure{P,S,N}) where {P,S,N} = N
247+
length(::PooledStatsProcedure{P,S}) where {P,S} = length(S.parameters)
221248
eltype(::Type{<:PooledStatsProcedure}) = SharedStatsStep
249+
firstindex(::PooledStatsProcedure{P,S}) where {P,S} = firstindex(S.parameters)
250+
lastindex(::PooledStatsProcedure{P,S}) where {P,S} = lastindex(S.parameters)
222251

223-
function iterate(ps::PooledStatsProcedure, state=deepcopy(ps.steps))
224-
state = state[BitArray(length.(state).>0)]
225-
length(state) > 0 || return nothing
226-
firsts = first.(state)
227-
for i in 1:length(firsts)
228-
nshared = length(_sharedby(firsts[i]))
229-
if nshared == 1
230-
deleteat!(state[i],1)
231-
return (firsts[i], state)
232-
else
233-
shared = BitArray(s==firsts[i] for s in firsts)
234-
if sum(shared) == nshared
235-
for p in state[shared]
236-
deleteat!(p, 1)
237-
end
238-
return (firsts[i], state)
239-
end
240-
end
241-
end
242-
error("bad construction of $(typeof(ps))")
243-
end
252+
getindex(ps::PooledStatsProcedure, i) = getindex(ps.steps, i)
244253

245-
show(io::IO, ps::PooledStatsProcedure) = print(io, typeof(ps).name)
254+
iterate(ps::PooledStatsProcedure, state=1) = iterate(ps.steps, state)
246255

247-
function show(io::IO, ::MIME"text/plain", ps::PooledStatsProcedure{P,S,N}) where {P,S,N}
248-
print(io, typeof(ps).name, " with ", N, " step")
249-
N > 1 ? print(io, "s ") : print(io, " ")
256+
show(io::IO, ps::PooledStatsProcedure) = print(io, typeof(ps).name.name)
257+
258+
function show(io::IO, ::MIME"text/plain", ps::PooledStatsProcedure{P,S}) where {P,S}
259+
nstep = length(S.parameters)
260+
print(io, typeof(ps).name.name, " with ", nstep, " step")
261+
nstep > 1 ? print(io, "s ") : print(io, " ")
250262
nps = length(P.parameters)
251263
print(io, "from ", nps, " procedure")
252264
nps > 1 ? print(io, "s:") : print(io, ":")
@@ -308,8 +320,6 @@ while ignoring the orders.
308320
≊(x::StatsSpec{A1,T}, y::StatsSpec{A2,T}) where {A1,A2,T} =
309321
x.args ≊ y.args
310322

311-
_procedure(sp::StatsSpec{A,T}) where {A,T} = T
312-
313323
function (sp::StatsSpec{A,T})(;
314324
verbose::Bool=false, keep=nothing, keepall::Bool=false) where {A,T}
315325
args = verbose ? merge(sp.args, (verbose=true,)) : sp.args
@@ -333,12 +343,12 @@ function (sp::StatsSpec{A,T})(;
333343
end
334344
end
335345

336-
show(io::IO, sp::StatsSpec{A,T}) where {A,T} = print(io, A==Symbol("") ? "unnamed" : A)
346+
show(io::IO, ::StatsSpec{A}) where {A} = print(io, A==Symbol("") ? "unnamed" : A)
337347

338-
_show_args(io::IO, sp::StatsSpec) = nothing
348+
_show_args(::IO, ::StatsSpec) = nothing
339349

340350
function show(io::IO, ::MIME"text/plain", sp::StatsSpec{A,T}) where {A,T}
341-
print(io, A==Symbol("") ? "unnamed" : A, " (", typeof(sp).name,
351+
print(io, A==Symbol("") ? "unnamed" : A, " (", typeof(sp).name.name,
342352
" for ", T.parameters[1], ")")
343353
_show_args(io, sp)
344354
end

0 commit comments

Comments
 (0)