Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 6af12d3

Browse files
committed
Improve StatsStep and add CheckVars
1 parent 3fafa42 commit 6af12d3

File tree

7 files changed

+242
-103
lines changed

7 files changed

+242
-103
lines changed

src/DiffinDiffsBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ export @fieldequal,
4848
nevertreated,
4949
NotYetTreatedParallel,
5050
notyettreated,
51+
treated,
5152

5253
TreatmentTerm,
5354
treat,
@@ -65,6 +66,7 @@ export @fieldequal,
6566
proceed,
6667

6768
CheckData,
69+
CheckVars,
6870

6971
DiffinDiffsEstimator,
7072
DefaultDID,

src/StatsProcedures.jl

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
StatsStep{Alias,F<:Function}
2+
StatsStep{Alias, F<:Function}
33
44
Specify the function for moving a step in an [`AbstractStatsProcedure`](@ref).
55
An instance of `StatsStep` is callable.
@@ -12,7 +12,7 @@ An instance of `StatsStep` is callable.
1212
(step::StatsStep{A,F})(ntargs::NamedTuple; verbose::Bool=false)
1313
1414
Call the singleton instance of the function type `F` with arguments
15-
formed by updating `NamedArgs` with `ntargs`.
15+
formed by updating the `NamedTuple` returned by [`namedargs(step)`](@ref) with `ntargs`.
1616
1717
A message with the name of the `StatsStep` is printed to `stdout`
1818
if a keyword `verbose` takes the value `true`
@@ -21,20 +21,30 @@ The value from `ntargs` supersedes the keyword argument
2121
in case both are specified.
2222
2323
## Returns
24-
- `NamedTuple`: named intermidiate results.
24+
- `NamedTuple`: named intermediate results.
2525
"""
26-
struct StatsStep{Alias,F<:Function} end
26+
struct StatsStep{Alias, F<:Function} end
2727

2828
_f(::StatsStep{A,F}) where {A,F} = F.instance
2929

30+
"""
31+
namedargs(s::StatsStep)
32+
33+
Return a `NamedTuple` with keys showing the names of arguments
34+
accepted by `s` and values representing the defaults.
35+
"""
36+
namedargs(s::StatsStep) = error("method for $(typeof(s)) is not defined")
37+
3038
_getargs(ntargs::NamedTuple, s::StatsStep) = _update(ntargs, namedargs(s))
3139
_update(a::NamedTuple{N1}, b::NamedTuple{N2}) where {N1,N2} =
3240
NamedTuple{N2}(map(n->getfield(sym_in(n, N1) ? a : b, n), N2))
3341

42+
_combinedargs(::StatsStep, ::Any) = ()
43+
3444
function (step::StatsStep{A,F})(ntargs::NamedTuple; verbose::Bool=false) where {A,F}
3545
haskey(ntargs, :verbose) && (verbose = ntargs.verbose)
3646
verbose && printstyled("Running ", step, "\n", color=:green)
37-
ret = F.instance(_getargs(ntargs, step)...)
47+
ret, share = F.instance(_getargs(ntargs, step)..., _combinedargs(step, (ntargs,))...)
3848
if ret isa NamedTuple
3949
return merge(ntargs, ret)
4050
elseif ret === nothing
@@ -56,15 +66,7 @@ function show(io::IO, ::MIME"text/plain", s::StatsStep{A,F}) where {A,F}
5666
end
5767

5868
"""
59-
namedargs(s::StatsStep)
60-
61-
Return a `NamedTuple` with keys showing the names of arguments
62-
accepted by `s` and values representing the defaults.
63-
"""
64-
namedargs(s::StatsStep) = error("method for $(typeof(s)) is not defined")
65-
66-
"""
67-
AbstractStatsProcedure{Alias,T<:NTuple{N,StatsStep} where N}
69+
AbstractStatsProcedure{Alias, T<:NTuple{N,StatsStep} where N}
6870
6971
Supertype for all types specifying the procedure for statistical estimation or inference.
7072
@@ -75,7 +77,7 @@ all subtypes of `AbstractStatsProcedure`.
7577
- `Alias::Symbol`: alias of the type for pretty-printing.
7678
- `T<:NTuple{N,StatsStep}`: steps involved in the procedure.
7779
"""
78-
abstract type AbstractStatsProcedure{A,T<:NTuple{N,StatsStep} where N} end
80+
abstract type AbstractStatsProcedure{Alias, T<:NTuple{N,StatsStep} where N} end
7981

8082
length(::AbstractStatsProcedure{A,T}) where {A,T} = length(T.parameters)
8183
eltype(::Type{<:AbstractStatsProcedure}) = StatsStep
@@ -109,7 +111,7 @@ function show(io::IO, ::MIME"text/plain", p::AbstractStatsProcedure{A,T}) where
109111
end
110112

111113
"""
112-
SharedStatsStep{T<:StatsStep,I}
114+
SharedStatsStep{T<:StatsStep, I}
113115
114116
A [`StatsStep`](@ref) that is possibly shared by
115117
multiple instances of procedures that are subtypes of [`AbstractStatsProcedure`](@ref).
@@ -119,7 +121,7 @@ See also [`PooledStatsProcedure`](@ref).
119121
- `T<:StatsStep`: type of the only field `step`.
120122
- `I`: indices of the procedures that share this step.
121123
"""
122-
struct SharedStatsStep{T<:StatsStep,I}
124+
struct SharedStatsStep{T<:StatsStep, I}
123125
step::T
124126
function SharedStatsStep(s::StatsStep, pid)
125127
pid = (unique!(sort!([pid...]))...,)
@@ -130,6 +132,7 @@ end
130132
_sharedby(::SharedStatsStep{T,I}) where {T,I} = I
131133
_f(s::SharedStatsStep) = _f(s.step)
132134
_getargs(ntargs::NamedTuple, s::SharedStatsStep) = _getargs(ntargs, s.step)
135+
_combinedargs(s::SharedStatsStep, v::AbstractArray) = _combinedargs(s.step, v)
133136

134137
show(io::IO, s::SharedStatsStep) = print(io, s.step)
135138

@@ -143,7 +146,7 @@ const SharedStatsSteps = NTuple{N, SharedStatsStep} where N
143146
const StatsProcedures = NTuple{N, AbstractStatsProcedure} where N
144147

145148
"""
146-
PooledStatsProcedure{P<:StatsProcedures,S<:SharedStatsSteps}
149+
PooledStatsProcedure{P<:StatsProcedures, S<:SharedStatsSteps}
147150
148151
A collection of procedures and shared steps.
149152
@@ -155,7 +158,7 @@ See also [`pool`](@ref).
155158
- `procs::P`: a tuple of instances of subtypes of [`AbstractStatsProcedure`](@ref).
156159
- `steps::S`: a tuple of [`SharedStatsStep`](@ref) for the procedures in `procs`.
157160
"""
158-
struct PooledStatsProcedure{P<:StatsProcedures,S<:SharedStatsSteps}
161+
struct PooledStatsProcedure{P<:StatsProcedures, S<:SharedStatsSteps}
159162
procs::P
160163
steps::S
161164
end
@@ -274,7 +277,7 @@ function show(io::IO, ::MIME"text/plain", ps::PooledStatsProcedure{P,S}) where {
274277
end
275278

276279
"""
277-
StatsSpec{Alias,T<:AbstractStatsProcedure}
280+
StatsSpec{Alias, T<:AbstractStatsProcedure}
278281
279282
Record the specification for a statistical procedure of type `T`.
280283
@@ -297,7 +300,7 @@ or the last value returned by the last [`StatsStep`](@ref) is returned.
297300
- `keep=nothing`: names (of type `Symbol`) of additional objects to be returned.
298301
- `keepall::Bool=false`: return all objects returned by each step.
299302
"""
300-
struct StatsSpec{Alias,T<:AbstractStatsProcedure}
303+
struct StatsSpec{Alias, T<:AbstractStatsProcedure}
301304
args::NamedTuple
302305
StatsSpec(name::Union{Symbol,String},
303306
T::Type{<:AbstractStatsProcedure}, args::NamedTuple) =
@@ -391,7 +394,7 @@ For end users, `Macro`s that generate `Expr`s for these function calls should be
391394
392395
Optional default arguments are merged
393396
with the arguments provided for each individual specification
394-
and replace the default values specified for each procedure.
397+
and supersede the default values specified for each procedure through [`namedargs`](@ref).
395398
These default arguments should be specified in the same pattern as
396399
how arguments are specified for each specification inside the code block,
397400
as `@specset` processes these arguments by calling
@@ -470,12 +473,18 @@ function proceed(sps::AbstractVector{<:StatsSpec};
470473
taskids = vcat((gids[steps.procs[i]] for i in _sharedby(step))...)
471474
tasks = groupview(r->_getargs(r, step), view(traces, taskids))
472475
for (ins, subtb) in pairs(tasks)
473-
ret = _f(step)(ins...)
476+
ret, share = _f(step)(ins..., _combinedargs(step, subtb)...)
474477
ntask += 1
475478
ntask_total += 1
476479
if ret !== nothing
477-
for i in eachindex(subtb)
478-
subtb[i] = merge(subtb[i], deepcopy(ret))
480+
if share
481+
for i in eachindex(subtb)
482+
subtb[i] = merge(subtb[i], ret)
483+
end
484+
else
485+
for i in eachindex(subtb)
486+
subtb[i] = merge(subtb[i], deepcopy(ret))
487+
end
479488
end
480489
end
481490
end

src/parallels.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ assume a parallel trends assumption holds over all the relevant time periods.
8585
abstract type TrendParallel{C,S} <: AbstractParallel{C,S} end
8686

8787
"""
88-
NeverTreatedParallel{T<:Integer,C,S} <: TrendParallel{C,S}
88+
NeverTreatedParallel{C,S,T<:Integer} <: TrendParallel{C,S}
8989
9090
Assume a parallel trends assumption holds between any group
9191
that received the treatment during the sample periods
@@ -97,15 +97,17 @@ See also [`nevertreated`](@ref).
9797
- `c::C`: a [`ParallelCondition`](@ref).
9898
- `s::S`: a [`ParallelStrength`](@ref).
9999
"""
100-
struct NeverTreatedParallel{T<:Integer,C,S} <: TrendParallel{C,S}
100+
struct NeverTreatedParallel{C,S,T<:Integer} <: TrendParallel{C,S}
101101
e::Vector{T}
102102
c::C
103103
s::S
104104
NeverTreatedParallel(e::Vector{T}, c::C, s::S) where
105-
{T<:Integer,C<:ParallelCondition,S<:ParallelStrength} =
106-
new{T,C,S}(unique!(sort!(e)), c, s)
105+
{C<:ParallelCondition,S<:ParallelStrength,T<:Integer} =
106+
new{C,S,T}(unique!(sort!(e)), c, s)
107107
end
108108

109+
treated(pr::NeverTreatedParallel, x) = !(x in pr.e)
110+
109111
show(io::IO, pr::NeverTreatedParallel) =
110112
print(IOContext(io, :compact=>true), "NeverTreated{", pr.c, ",", pr.s, "}(", pr.e,")")
111113

@@ -156,7 +158,7 @@ A wrapper method of `nevertreated` for working with `@formula`.
156158
@unpack nevertreated
157159

158160
"""
159-
NotYetTreatedParallel{T<:Integer,C,S} <: TrendParallel{C,S}
161+
NotYetTreatedParallel{C,S,T<:Integer} <: TrendParallel{C,S}
160162
161163
Assume a parallel trends assumption holds between any group
162164
that received the treatment relatively early
@@ -174,17 +176,19 @@ See also [`notyettreated`](@ref).
174176
- never-treated groups are included and use indices with smaller values;
175177
- the sample has a rotating panel structure with periods overlapping with some others.
176178
"""
177-
struct NotYetTreatedParallel{T<:Integer,C,S} <: TrendParallel{C,S}
179+
struct NotYetTreatedParallel{C,S,T<:Integer} <: TrendParallel{C,S}
178180
e::Vector{T}
179181
emin::Union{Vector{T},Nothing}
180182
c::C
181183
s::S
182184
NotYetTreatedParallel(e::Vector{T}, emin::Union{Vector{T},Nothing}, c::C, s::S) where
183-
{T<:Integer,C<:ParallelCondition,S<:ParallelStrength} =
184-
new{T,C,S}(unique!(sort!(e)),
185+
{C<:ParallelCondition,S<:ParallelStrength,T<:Integer} =
186+
new{C,S,T}(unique!(sort!(e)),
185187
emin isa Nothing ? emin : unique!(sort!(emin)), c, s)
186188
end
187189

190+
treated(pr::NotYetTreatedParallel, x) = !(x in pr.e)
191+
188192
function show(io::IO, pr::NotYetTreatedParallel)
189193
print(IOContext(io, :compact=>true), "NotYetTreated{", pr.c, ",", pr.s, "}(", pr.e, ", ")
190194
print(IOContext(io, :compact=>true), pr.emin isa Nothing ? "NA" : pr.emin, ")")

src/procedures.jl

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,21 @@
11
"""
22
checkdata(args...)
33
4-
Check `data` is a `Table` and find rows with nonmissing values for variables.
4+
Check that `data` is a `Table` and find the rows that are valid given the options `subset` and `weights`.
55
See also [`CheckData`](@ref).
6-
7-
# Returns
8-
- `vars::Vector{Symbol}`: column names of relevant variables.
9-
- `esample::BitArray`: Boolean indices for rows with nonmissing values for the variables.
106
"""
11-
function checkdata(data, tr::AbstractTreatment, pr::AbstractParallel,
12-
yterm::AbstractTerm, treatname::Symbol, xterms::TupleTerm,
13-
weights::Union{Symbol, Nothing}, subset::Union{AbstractVector, Nothing})
7+
function checkdata(data, subset::Union{AbstractVector, Nothing},
8+
weights::Union{Symbol, Nothing})
149

1510
istable(data) ||
1611
throw(ArgumentError("expect `data` being a `Table` while receiving a $(typeof(data))"))
1712

18-
vars = union([treatname], (termvars(t) for t in (tr, pr, yterm, xterms))...)
19-
esample = BitArray(all(v->!ismissing(getproperty(row, v)), vars) for row in rows(data))
20-
2113
if subset !== nothing
2214
length(subset) != size(data, 1) &&
2315
throw(DimensionMismatch("`data` of $(size(data, 1)) rows cannot be matched with `subset` vector of $(length(subset)) elements"))
24-
esample .&= .!ismissing.(subset) .& subset
16+
esample = .!ismissing.(subset) .& subset
17+
else
18+
esample = trues(size(data, 1))
2519
end
2620

2721
if weights !== nothing
@@ -30,18 +24,89 @@ function checkdata(data, tr::AbstractTreatment, pr::AbstractParallel,
3024
end
3125

3226
sum(esample) == 0 && error("no nonmissing data")
33-
34-
return (vars=vars, esample=esample)
27+
return (esample=esample,), false
3528
end
3629

3730
"""
38-
CheckData
31+
CheckData <: StatsStep
3932
40-
A [`StatsStep`](@ref) that calls [`DiffinDiffsBase.checkdata`](@ref)
33+
Call [`DiffinDiffsBase.checkdata`](@ref)
4134
for some preliminary checks of the input data.
4235
"""
4336
const CheckData = StatsStep{:CheckData, typeof(checkdata)}
4437

45-
namedargs(::CheckData) = (data=nothing, tr=nothing, pr=nothing, yterm=nothing,
46-
treatname=nothing, xterms=nothing, weights=nothing, subset=nothing)
38+
namedargs(::CheckData) = (data=nothing, subset=nothing, weights=nothing)
39+
40+
function _overlaptime(tr::DynamicTreatment, tr_rows::BitArray, data)
41+
control_time = Set(view(getcolumn(data, tr.time), .!tr_rows))
42+
treated_time = Set(view(getcolumn(data, tr.time), tr_rows))
43+
return intersect(control_time, treated_time), control_time, treated_time
44+
end
45+
46+
function overlap!(esample::BitArray, tr_rows::BitArray, tr::DynamicTreatment,
47+
::NeverTreatedParallel{Unconditional}, treatname::Symbol, data)
48+
overlap_time, control_time, treated_time = _overlaptime(tr, tr_rows, data)
49+
length(control_time)==length(treated_time)==length(overlap_time) ||
50+
(esample .&= getcolumn(data, tr.time).∈(overlap_time,))
51+
tr_rows .&= esample
52+
end
53+
54+
function overlap!(esample::BitArray, tr_rows::BitArray, tr::DynamicTreatment,
55+
pr::NotYetTreatedParallel{Unconditional}, treatname::Symbol, data)
56+
overlap_time, _c, _t = _overlaptime(tr, tr_rows, data)
57+
timetype = eltype(overlap_time)
58+
if timetype <: Integer
59+
emin = pr.emin === nothing ? minimum(pr.e) : pr.emin[1]
60+
valid_cohort = filter(x -> x < emin || x in pr.e, overlap_time)
61+
filter!(x -> x < emin, overlap_time)
62+
esample .&= (getcolumn(data, tr.time).∈(overlap_time,)) .&
63+
(getcolumn(data, treatname).∈(valid_cohort,))
64+
end
65+
tr_rows .&= esample
66+
end
67+
68+
"""
69+
checkvars(args...)
70+
71+
Return rows with observations that are nonmissing and satisfy the overlap condition
72+
and rows for observations from treated units.
73+
See also [`CheckVars`](@ref).
74+
"""
75+
function checkvars(data, tr::AbstractTreatment, pr::AbstractParallel,
76+
yterm::AbstractTerm, treatname::Symbol, treatintterms::TupleTerm,
77+
xterms::TupleTerm, esample::BitArray)
78+
79+
treatvars = union([treatname], (termvars(t) for t in (tr, pr, treatintterms))...)
80+
for v in treatvars
81+
eltype(getcolumn(data, v)) <: Union{Missing, Integer} ||
82+
throw(ArgumentError("data column $v has unaccepted element type"))
83+
end
84+
# Values of treatintterms from units in control groups are ignored
85+
allvars = union(treatvars, (termvars(t) for t in (yterm, xterms))...)
86+
treatedvars = setdiff(allvars, termvars(treatintterms))
87+
tr_rows = falses(length(esample))
88+
@inbounds for i in eachindex(esample)
89+
if esample[i]
90+
if treated(pr, getcolumn(data, treatname)[i])
91+
esample[i] = all(v->!ismissing(getcolumn(data, v)[i]), treatedvars)
92+
esample[i] && (tr_rows[i] = true)
93+
else
94+
esample[i] = all(v->!ismissing(getcolumn(data, v)[i]), allvars)
95+
end
96+
end
97+
end
98+
99+
overlap!(esample, tr_rows, tr, pr, treatname, data)
100+
sum(esample) == 0 && error("no nonmissing data")
101+
return (esample=esample, tr_rows=tr_rows), false
102+
end
103+
104+
"""
105+
CheckVars <: StatsStep
106+
107+
Call [`DiffinDiffsBase.checkvars`](@ref) to obtain the valid rows from the relevant columns.
108+
"""
109+
const CheckVars = StatsStep{:CheckVars, typeof(checkvars)}
47110

111+
namedargs(::CheckVars) = (data=nothing, tr=nothing, pr=nothing,
112+
yterm=nothing, treatname=nothing, treatintterms=(), xterms=(), esample=nothing)

0 commit comments

Comments
 (0)