Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit e653da5

Browse files
authored
Remove type parameters for SharedStatsStep and PooledStatsProcedure (#9)
1 parent 11ff4aa commit e653da5

File tree

3 files changed

+73
-65
lines changed

3 files changed

+73
-65
lines changed

src/StatsProcedures.jl

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -163,42 +163,45 @@ function show(io::IO, ::MIME"text/plain", p::AbstractStatsProcedure{A,T}) where
163163
end
164164

165165
"""
166-
SharedStatsStep{T<:StatsStep, I}
166+
SharedStatsStep
167167
168168
A [`StatsStep`](@ref) that is possibly shared by
169169
multiple instances of procedures that are subtypes of [`AbstractStatsProcedure`](@ref).
170170
See also [`PooledStatsProcedure`](@ref).
171171
172-
# Parameters
173-
- `T<:StatsStep`: type of the only field `step`.
174-
- `I`: indices of the procedures that share this step.
172+
# Fields
173+
- `step::StatsStep`: the `step` that may be shared.
174+
- `ids::Vector{Int}`: indices of procedures that share `step`.
175175
"""
176-
struct SharedStatsStep{T<:StatsStep, I}
177-
step::T
178-
function SharedStatsStep(s::StatsStep, pid)
179-
pid = (unique!(sort!([pid...]))...,)
180-
return new{typeof(s), pid}(s)
176+
struct SharedStatsStep
177+
step::StatsStep
178+
ids::Vector{Int}
179+
function SharedStatsStep(s::StatsStep, ids)
180+
ids = unique!(sort!([ids...]))
181+
return new(s, ids)
181182
end
182183
end
183184

184-
_sharedby(::SharedStatsStep{T,I}) where {T,I} = I
185+
_sharedby(s::SharedStatsStep) = s.ids
185186
_f(s::SharedStatsStep) = _f(s.step)
186187
groupargs(s::SharedStatsStep, @nospecialize(ntargs::NamedTuple)) = groupargs(s.step, ntargs)
187188
combinedargs(s::SharedStatsStep, v::AbstractArray) = combinedargs(s.step, v)
188189

190+
==(x::SharedStatsStep, y::SharedStatsStep) =
191+
x.step == y.step && x.ids == y.ids
192+
189193
show(io::IO, s::SharedStatsStep) = print(io, s.step)
190194

191-
function show(io::IO, ::MIME"text/plain", s::SharedStatsStep{T,I}) where {T,I}
192-
nps = length(I)
195+
function show(io::IO, ::MIME"text/plain", s::SharedStatsStep)
196+
nps = length(s.ids)
193197
print(io, s.step, " (StatsStep shared by ", nps, " procedure")
194198
nps > 1 ? print(io, "s)") : print(io, ")")
195199
end
196200

197-
const SharedStatsSteps = NTuple{N, SharedStatsStep} where N
198201
const StatsProcedures = NTuple{N, AbstractStatsProcedure} where N
199202

200203
"""
201-
PooledStatsProcedure{P<:StatsProcedures, S<:SharedStatsSteps}
204+
PooledStatsProcedure
202205
203206
A collection of procedures and shared steps.
204207
@@ -207,36 +210,37 @@ in a way that helps avoid repeating identical steps.
207210
See also [`pool`](@ref).
208211
209212
# Fields
210-
- `procs::P`: a tuple of instances of subtypes of [`AbstractStatsProcedure`](@ref).
211-
- `steps::S`: a tuple of [`SharedStatsStep`](@ref) for the procedures in `procs`.
213+
- `procs::StatsProcedures`: a tuple of instances of subtypes of [`AbstractStatsProcedure`](@ref).
214+
- `steps::Vector{SharedStatsStep}`: sorted [`SharedStatsStep`](@ref)s.
212215
"""
213-
struct PooledStatsProcedure{P<:StatsProcedures, S<:SharedStatsSteps}
214-
procs::P
215-
steps::S
216+
struct PooledStatsProcedure
217+
procs::StatsProcedures
218+
steps::Vector{SharedStatsStep}
216219
end
217220

218-
function _sort(psteps::NTuple{N, Vector{SharedStatsStep}}) where N
221+
function _sort(psteps::Vector{Vector{SharedStatsStep}})
222+
N = length(psteps)
219223
sorted = SharedStatsStep[]
220224
state = [length(s) for s in psteps]
221-
pending = BitArray(state.>0)
225+
pending = state .> 0
222226
while any(pending)
223227
pid = (1:N)[pending]
224228
firsts = [psteps[i][end-state[i]+1] for i in pid]
225-
for i in 1:length(firsts)
226-
nshared = length(_sharedby(firsts[i]))
229+
for (i, fstep) in enumerate(firsts)
230+
nshared = length(_sharedby(fstep))
227231
if nshared == 1
228-
push!(sorted, firsts[i])
232+
push!(sorted, fstep)
229233
state[pid[i]] -= 1
230234
else
231-
shared = BitArray(s==firsts[i] for s in firsts)
235+
shared = firsts .== Ref(fstep)
232236
if sum(shared) == nshared
233-
push!(sorted, firsts[i])
237+
push!(sorted, fstep)
234238
state[pid[shared]] .-= 1
235239
break
236240
end
237241
end
238242
end
239-
pending = BitArray(state.>0)
243+
pending .= state .> 0
240244
end
241245
return sorted
242246
end
@@ -260,7 +264,7 @@ function pool(ps::AbstractStatsProcedure...)
260264
steps = union(ps...)
261265
N = sum(length.(ps))
262266
if length(steps) < N
263-
shared = ((Vector{SharedStatsStep}(undef, length(p)) for p in ps)...,)
267+
shared = [Vector{SharedStatsStep}(undef, length(p)) for p in ps]
264268
step_pos = Dict{StatsStep,Dict{Int64,Int64}}()
265269
for (i, p) in enumerate(ps)
266270
for n in 1:length(p)
@@ -298,33 +302,36 @@ function pool(ps::AbstractStatsProcedure...)
298302
end
299303
end
300304
end
301-
shared = (_sort(shared)...,)
305+
shared = _sort(shared)
302306
else
303-
shared = ((SharedStatsStep(s, i) for (i,p) in enumerate(ps) for s in p)...,)
307+
shared = [SharedStatsStep(s, i) for (i,p) in enumerate(ps) for s in p]
304308
end
305-
return PooledStatsProcedure{typeof(ps), typeof(shared)}(ps, shared)
309+
return PooledStatsProcedure(ps, shared)
306310
end
307311

308-
length(::PooledStatsProcedure{P,S}) where {P,S} = length(S.parameters)
309-
eltype(::Type{<:PooledStatsProcedure}) = SharedStatsStep
310-
firstindex(::PooledStatsProcedure{P,S}) where {P,S} = firstindex(S.parameters)
311-
lastindex(::PooledStatsProcedure{P,S}) where {P,S} = lastindex(S.parameters)
312+
length(p::PooledStatsProcedure) = length(p.steps)
313+
eltype(::Type{PooledStatsProcedure}) = SharedStatsStep
314+
firstindex(p::PooledStatsProcedure) = firstindex(p.steps)
315+
lastindex(p::PooledStatsProcedure) = lastindex(p.steps)
316+
317+
getindex(p::PooledStatsProcedure, i) = getindex(p.steps, i)
312318

313-
getindex(ps::PooledStatsProcedure, i) = getindex(ps.steps, i)
319+
iterate(p::PooledStatsProcedure, state=1) = iterate(p.steps, state)
314320

315-
iterate(ps::PooledStatsProcedure, state=1) = iterate(ps.steps, state)
321+
==(x::PooledStatsProcedure, y::PooledStatsProcedure) =
322+
x.procs == y.procs && x.steps == y.steps
316323

317-
show(io::IO, ps::PooledStatsProcedure) = print(io, typeof(ps).name.name)
324+
show(io::IO, p::PooledStatsProcedure) = print(io, typeof(p).name.name)
318325

319-
function show(io::IO, ::MIME"text/plain", ps::PooledStatsProcedure{P,S}) where {P,S}
320-
nstep = length(S.parameters)
321-
print(io, typeof(ps).name.name, " with ", nstep, " step")
326+
function show(io::IO, ::MIME"text/plain", p::PooledStatsProcedure)
327+
nstep = length(p.steps)
328+
print(io, typeof(p).name.name, " with ", nstep, " step")
322329
nstep > 1 ? print(io, "s ") : print(io, " ")
323-
nps = length(P.parameters)
330+
nps = length(p.procs)
324331
print(io, "from ", nps, " procedure")
325332
nps > 1 ? print(io, "s:") : print(io, ":")
326-
for p in P.parameters
327-
print(io, "\n ", p.parameters[1])
333+
for p in p.procs
334+
print(io, "\n ", typeof(p).parameters[1])
328335
end
329336
end
330337

@@ -462,8 +469,8 @@ function proceed(sps::AbstractVector{<:StatsSpec};
462469
ret, share = ret
463470
else
464471
fname = typeof(_f(step)).name.mt.name
465-
stepname = typeof(step).parameters[1].parameters[1]
466-
error("unexpected $(typeof(ret)) returned from $fname associated with StatsStep $stepname")
472+
stepname = typeof(step.step).parameters[1]
473+
error("unexpected type $(typeof(ret)) of object returned from $fname associated with StatsStep $stepname")
467474
end
468475
ntask += 1
469476
ntask_total += 1

src/did.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ for each corresponding row in `treatinds`.
231231
This method only selects estimates for treatment coefficients.
232232
Covariates are not taken into account.
233233
"""
234-
coef(f::Function, r::DIDResult) = view(r.coef, 1:length(r.treatinds))[f.(r.treatinds)]
234+
@inline coef(f::Function, r::DIDResult) =
235+
view(r.coef, 1:length(r.treatinds))[f.(r.treatinds)]
235236

236237
"""
237238
vcov(r::DIDResult)
@@ -275,7 +276,7 @@ for each corresponding row in `treatinds`.
275276
This method only selects estimates for treatment coefficients.
276277
Covariates are not taken into account.
277278
"""
278-
function vcov(f::Function, r::DIDResult)
279+
@inline function vcov(f::Function, r::DIDResult)
279280
N = length(r.treatinds)
280281
inds = f.(r.treatinds)
281282
return view(r.vcov, 1:N, 1:N)[inds, inds]

test/StatsProcedures.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using DiffinDiffsBase: _f, _get, groupargs,
2-
_sharedby, _show_args, _args_kwargs, _parse!, proceed
2+
_sharedby, _show_args, _args_kwargs, _parse!, pool, proceed
33
import DiffinDiffsBase: required, default, transformed, combinedargs
44

55
testvoidstep(a::String) = NamedTuple(), false
@@ -129,8 +129,8 @@ end
129129
@testset "SharedStatsStep" begin
130130
s1 = SharedStatsStep(TestRegStep(), 1)
131131
s2 = SharedStatsStep(TestRegStep(), [3,2])
132-
@test _sharedby(s1) == (1,)
133-
@test _sharedby(s2) == (2,3)
132+
@test _sharedby(s1) == [1]
133+
@test _sharedby(s2) == [2,3]
134134
@test _f(s1) == testregstep
135135
@test groupargs(s1, NamedTuple()) == ("a", "b")
136136

@@ -144,47 +144,47 @@ end
144144

145145
@testset "PooledStatsProcedure" begin
146146
ps = (rp,)
147-
shared = ((SharedStatsStep(s, 1) for s in rp)...,)
147+
shared = [SharedStatsStep(s, 1) for s in rp]
148148
p1 = pool(rp)
149-
@test p1 == PooledStatsProcedure{typeof(ps), typeof(shared)}(ps, shared)
149+
@test p1 == PooledStatsProcedure(ps, shared)
150150
@test length(p1) == 3
151151
@test eltype(PooledStatsProcedure) == SharedStatsStep
152152
@test firstindex(p1) == 1
153153
@test lastindex(p1) == 3
154154
@test p1[1] == SharedStatsStep(rp[1], 1)
155-
@test p1[1:3] == ((SharedStatsStep(s, 1) for s in rp)...,)
155+
@test p1[1:3] == [SharedStatsStep(s, 1) for s in rp]
156156
@test iterate(p1) == (shared[1], 2)
157157
@test iterate(p1, 2) == (shared[2], 3)
158158

159159
ps = (rp, rp)
160-
shared = ((SharedStatsStep(s, (1,2)) for s in rp)...,)
160+
shared = [SharedStatsStep(s, (1,2)) for s in rp]
161161
p2 = pool(rp, rp)
162-
@test p2 == PooledStatsProcedure{typeof(ps), typeof(shared)}(ps, shared)
162+
@test p2 == PooledStatsProcedure(ps, shared)
163163
@test length(p2) == 3
164164

165165
ps = (up, up)
166-
shared = ((SharedStatsStep(s, (1,2)) for s in up)...,)
166+
shared = [SharedStatsStep(s, (1,2)) for s in up]
167167
p3 = pool(up, up)
168-
@test p3 == PooledStatsProcedure{typeof(ps), typeof(shared)}(ps, shared)
168+
@test p3 == PooledStatsProcedure(ps, shared)
169169
@test length(p3) == 1
170170

171171
ps = (np,)
172-
shared = ()
172+
shared = []
173173
p4 = pool(np)
174-
@test p4 == PooledStatsProcedure{typeof(ps), typeof(shared)}(ps, shared)
174+
@test p4 == PooledStatsProcedure(ps, shared)
175175
@test length(p4) == 0
176176

177177
ps = (up, rp)
178-
shared = (SharedStatsStep(rp[1], 2), SharedStatsStep(rp[2], (1,2)), SharedStatsStep(rp[3], 2))
178+
shared = [SharedStatsStep(rp[1], 2), SharedStatsStep(rp[2], (1,2)), SharedStatsStep(rp[3], 2)]
179179
p5 = pool(up, rp)
180-
@test p5 == PooledStatsProcedure{typeof(ps), typeof(shared)}(ps, shared)
180+
@test p5 == PooledStatsProcedure(ps, shared)
181181
@test length(p5) == 3
182182

183183
ps = (rp, ip)
184-
shared = (SharedStatsStep(rp[1], 1), SharedStatsStep(ip[1], 2), SharedStatsStep(rp[2], 1),
185-
SharedStatsStep(ip[2], 2), SharedStatsStep(rp[3], (1,2)))
184+
shared = [SharedStatsStep(rp[1], 1), SharedStatsStep(ip[1], 2), SharedStatsStep(rp[2], 1),
185+
SharedStatsStep(ip[2], 2), SharedStatsStep(rp[3], (1,2))]
186186
p6 = pool(rp, ip)
187-
@test p6 == PooledStatsProcedure{typeof(ps), typeof(shared)}(ps, shared)
187+
@test p6 == PooledStatsProcedure(ps, shared)
188188
@test length(p6) == 5
189189

190190
@test sprint(show, p1) == "PooledStatsProcedure"

0 commit comments

Comments
 (0)