Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit b276c26

Browse files
committed
Improve treatment and parallel types
1 parent 6af12d3 commit b276c26

File tree

11 files changed

+201
-170
lines changed

11 files changed

+201
-170
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
99
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1010
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1111
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
12-
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1312
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
1413
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1514

@@ -19,7 +18,6 @@ Combinatorics = "1"
1918
MacroTools = "0.5"
2019
Reexport = "0.2, 1"
2120
SplitApplyCombine = "1.1"
22-
StatsBase = "0.33"
2321
StatsModels = "0.6.18"
2422
Tables = "1.2"
2523
julia = "1.3"

src/DiffinDiffsBase.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ using Combinatorics: combinations
44
using CSV: File
55
using MacroTools: @capture, isexpr, postwalk
66
using Reexport
7-
using StatsBase
87
@reexport using StatsModels
98
using StatsModels: TupleTerm
109
using SplitApplyCombine: groupfind, groupview
11-
using Tables: columntable, istable, rows, columns, getcolumn
10+
using Tables: istable, getcolumn
1211

1312
import Base: ==, show, union
1413
import Base: eltype, firstindex, lastindex, getindex, iterate, length, sym_in
@@ -19,14 +18,11 @@ export TupleTerm
1918

2019
export @fieldequal,
2120
eachterm,
22-
c,
23-
unpack,
24-
kwarg,
21+
cb,
2522
@unpack,
2623
,
2724
exampledata,
2825

29-
EleOrVec,
3026
TreatmentSharpness,
3127
SharpDesign,
3228
sharp,
@@ -48,7 +44,7 @@ export @fieldequal,
4844
nevertreated,
4945
NotYetTreatedParallel,
5046
notyettreated,
51-
treated,
47+
istreated,
5248

5349
TreatmentTerm,
5450
treat,

src/parallels.jl

Lines changed: 60 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Assume some notion of parallel holds without conditions.
1616
"""
1717
struct Unconditional <: ParallelCondition end
1818

19-
show(io::IO, C::Unconditional) =
19+
show(io::IO, ::Unconditional) =
2020
get(io, :compact, false) ? print(io, "U") : print(io, "Unconditional")
2121

2222
"""
@@ -50,7 +50,7 @@ Assume some notion of parallel holds exactly.
5050
"""
5151
struct Exact <: ParallelStrength end
5252

53-
show(io::IO, S::Exact) =
53+
show(io::IO, ::Exact) =
5454
get(io, :compact, false) ? print(io, "P") : print(io, "Parallel")
5555

5656
"""
@@ -68,11 +68,11 @@ Supertype for all types assuming some notion of parallel holds approximately.
6868
abstract type Approximate <: ParallelStrength end
6969

7070
"""
71-
AbstractParallel{C<:ParallelCondition,S<:ParallelStrength}
71+
AbstractParallel{C<:ParallelCondition, S<:ParallelStrength}
7272
7373
Supertype for all parallel types.
7474
"""
75-
abstract type AbstractParallel{C<:ParallelCondition,S<:ParallelStrength} end
75+
abstract type AbstractParallel{C<:ParallelCondition, S<:ParallelStrength} end
7676

7777
@fieldequal AbstractParallel
7878

@@ -85,31 +85,33 @@ assume a parallel trends assumption holds over all the relevant time periods.
8585
abstract type TrendParallel{C,S} <: AbstractParallel{C,S} end
8686

8787
"""
88-
NeverTreatedParallel{C,S,T<:Integer} <: TrendParallel{C,S}
88+
NeverTreatedParallel{C,S,T<:Tuple} <: TrendParallel{C,S}
8989
9090
Assume a parallel trends assumption holds between any group
9191
that received the treatment during the sample periods
9292
and a group that did not receive any treatment in any sample period.
9393
See also [`nevertreated`](@ref).
9494
9595
# Fields
96-
- `e::Vector{T}`: group indices for units that did not receive any treatment.
97-
- `c::C`: a [`ParallelCondition`](@ref).
98-
- `s::S`: a [`ParallelStrength`](@ref).
96+
- `e::T`: group indices for units that did not receive any treatment.
97+
- `c::C`: an instance of [`ParallelCondition`](@ref).
98+
- `s::S`: an instance of [`ParallelStrength`](@ref).
9999
"""
100-
struct NeverTreatedParallel{C,S,T<:Integer} <: TrendParallel{C,S}
101-
e::Vector{T}
100+
struct NeverTreatedParallel{C,S,T<:Tuple} <: TrendParallel{C,S}
101+
e::T
102102
c::C
103103
s::S
104-
NeverTreatedParallel(e::Vector{T}, c::C, s::S) where
105-
{C<:ParallelCondition,S<:ParallelStrength,T<:Integer} =
106-
new{C,S,T}(unique!(sort!(e)), c, s)
104+
function NeverTreatedParallel(e, c::ParallelCondition, s::ParallelStrength)
105+
e = (unique!(sort!([e...]))...,)
106+
isempty(e) && error("field `e` cannot be empty")
107+
return new{typeof(c),typeof(s),typeof(e)}(e, c, s)
108+
end
107109
end
108110

109-
treated(pr::NeverTreatedParallel, x) = !(x in pr.e)
111+
istreated(pr::NeverTreatedParallel, x) = !(x in pr.e)
110112

111113
show(io::IO, pr::NeverTreatedParallel) =
112-
print(IOContext(io, :compact=>true), "NeverTreated{", pr.c, ",", pr.s, "}(", pr.e,")")
114+
print(IOContext(io, :compact=>true), "NeverTreated{", pr.c, ",", pr.s, "}", pr.e)
113115

114116
function show(io::IO, ::MIME"text/plain", pr::NeverTreatedParallel)
115117
println(io, pr.s, " trends with any never-treated group:")
@@ -118,11 +120,10 @@ function show(io::IO, ::MIME"text/plain", pr::NeverTreatedParallel)
118120
end
119121

120122
"""
121-
nevertreated(itr, c::ParallelCondition, s::ParallelStrength)
122-
nevertreated(itr; c=Unconditional(), s=Exact())
123+
nevertreated(e, c::ParallelCondition, s::ParallelStrength)
124+
nevertreated(e; c=Unconditional(), s=Exact())
123125
124-
Construct a [`NeverTreatedParallel`](@ref) with field `e`
125-
set by unique elements in `itr`.
126+
Construct a [`NeverTreatedParallel`](@ref) with fields set by the arguments.
126127
By default, `c` is set as [`Unconditional()`](@ref)
127128
and `s` is set as [`Exact()`](@ref).
128129
When working with `@formula`,
@@ -132,23 +133,23 @@ a wrapper method of `nevertreated` calls this method.
132133
```jldoctest; setup = :(using DiffinDiffsBase)
133134
julia> nevertreated(-1)
134135
Parallel trends with any never-treated group:
135-
Never-treated groups: [-1]
136+
Never-treated groups: (-1,)
136137
137138
julia> typeof(nevertreated(-1))
138-
NeverTreatedParallel{Int64,Unconditional,Exact}
139+
NeverTreatedParallel{Unconditional,Exact,Tuple{Int64}}
139140
140141
julia> nevertreated([-1, 0])
141142
Parallel trends with any never-treated group:
142-
Never-treated groups: [-1, 0]
143+
Never-treated groups: (-1, 0)
143144
144145
julia> nevertreated([-1, 0]) == nevertreated(-1:0) == nevertreated(Set([-1, 0]))
145146
true
146147
```
147148
"""
148-
nevertreated(itr, c::ParallelCondition, s::ParallelStrength) =
149-
NeverTreatedParallel([itr...], c, s)
150-
nevertreated(itr; c::ParallelCondition=Unconditional(), s::ParallelStrength=Exact()) =
151-
NeverTreatedParallel([itr...], c, s)
149+
nevertreated(e, c::ParallelCondition, s::ParallelStrength) =
150+
NeverTreatedParallel(e, c, s)
151+
nevertreated(e; c::ParallelCondition=Unconditional(), s::ParallelStrength=Exact()) =
152+
NeverTreatedParallel(e, c, s)
152153

153154
"""
154155
nevertreated(ts::AbstractTerm...)
@@ -158,56 +159,57 @@ A wrapper method of `nevertreated` for working with `@formula`.
158159
@unpack nevertreated
159160

160161
"""
161-
NotYetTreatedParallel{C,S,T<:Integer} <: TrendParallel{C,S}
162+
NotYetTreatedParallel{C,S,T1<:Tuple,T2<:Tuple} <: TrendParallel{C,S}
162163
163164
Assume a parallel trends assumption holds between any group
164165
that received the treatment relatively early
165166
and any group that received the treatment relatively late (or never receved).
166167
See also [`notyettreated`](@ref).
167168
168169
# Fields
169-
- `e::Vector{T}`: group indices for units that received the treatment relatively late.
170-
- `emin::Union{Vector{T},Nothing}`: user-specified period(s) when units in a group in `e` started to receive treatment.
171-
- `c::C`: a [`ParallelCondition`](@ref).
172-
- `s::S`: a [`ParallelStrength`](@ref).
170+
- `e::T1`: group indices for units that received the treatment relatively late.
171+
- `ecut::T2`: user-specified period(s) when units in a group in `e` started to receive treatment.
172+
- `c::C`: an instance of [`ParallelCondition`](@ref).
173+
- `s::S`: an instance of [`ParallelStrength`](@ref).
173174
174175
!!! note
175-
`emin` could be different from `minimum(e)` if
176+
`ecut` could be different from `minimum(e)` if
176177
- never-treated groups are included and use indices with smaller values;
177178
- the sample has a rotating panel structure with periods overlapping with some others.
178179
"""
179-
struct NotYetTreatedParallel{C,S,T<:Integer} <: TrendParallel{C,S}
180-
e::Vector{T}
181-
emin::Union{Vector{T},Nothing}
180+
struct NotYetTreatedParallel{C,S,T1<:Tuple,T2<:Tuple} <: TrendParallel{C,S}
181+
e::T1
182+
ecut::T2
182183
c::C
183184
s::S
184-
NotYetTreatedParallel(e::Vector{T}, emin::Union{Vector{T},Nothing}, c::C, s::S) where
185-
{C<:ParallelCondition,S<:ParallelStrength,T<:Integer} =
186-
new{C,S,T}(unique!(sort!(e)),
187-
emin isa Nothing ? emin : unique!(sort!(emin)), c, s)
185+
function NotYetTreatedParallel(e, ecut, c::ParallelCondition, s::ParallelStrength)
186+
e = (unique!(sort!([e...]))...,)
187+
isempty(e) && error("field `e` cannot be empty")
188+
ecut = (unique!(sort!([ecut...]))...,)
189+
isempty(ecut) && error("field `ecut` cannot be empty")
190+
return new{typeof(c),typeof(s),typeof(e),typeof(ecut)}(e, ecut, c, s)
191+
end
188192
end
189193

190-
treated(pr::NotYetTreatedParallel, x) = !(x in pr.e)
194+
istreated(pr::NotYetTreatedParallel, x) = !(x in pr.e)
191195

192196
function show(io::IO, pr::NotYetTreatedParallel)
193-
print(IOContext(io, :compact=>true), "NotYetTreated{", pr.c, ",", pr.s, "}(", pr.e, ", ")
194-
print(IOContext(io, :compact=>true), pr.emin isa Nothing ? "NA" : pr.emin, ")")
197+
print(IOContext(io, :compact=>true), "NotYetTreated{", pr.c, ",", pr.s, "}", pr.e)
195198
end
196199

197200
function show(io::IO, ::MIME"text/plain", pr::NotYetTreatedParallel)
198201
println(io, pr.s, " trends with any not-yet-treated group:")
199202
println(io, " Not-yet-treated groups: ", pr.e)
200-
print(io, " Treated since: ", pr.emin isa Nothing ? "not specified" : pr.emin)
203+
print(io, " Treated since: ", pr.ecut)
201204
pr.c isa Unconditional || print(io, "\n ", pr.c)
202205
end
203206

204207
"""
205-
notyettreated(itr, emin, c::ParallelCondition, s::ParallelStrength)
206-
notyettreated(itr, emin=nothing; c=Unconditional(), s=Exact())
208+
notyettreated(e, ecut, c::ParallelCondition, s::ParallelStrength)
209+
notyettreated(e, ecut=minimum(e); c=Unconditional(), s=Exact())
207210
208211
Construct a [`NotYetTreatedParallel`](@ref) with
209-
elements in `itr` for field `e` and optional `emin`
210-
if `emin` is not `minimum(e)`.
212+
fields set by the arguments.
211213
By default, `c` is set as [`Unconditional()`](@ref)
212214
and `s` is set as [`Exact()`](@ref).
213215
When working with `@formula`,
@@ -217,28 +219,28 @@ a wrapper method of `notyettreated` calls this method.
217219
```jldoctest; setup = :(using DiffinDiffsBase)
218220
julia> notyettreated(5)
219221
Parallel trends with any not-yet-treated group:
220-
Not-yet-treated groups: [5]
221-
Treated since: not specified
222+
Not-yet-treated groups: (5,)
223+
Treated since: (5,)
222224
223225
julia> typeof(notyettreated(5))
224-
NotYetTreatedParallel{Int64,Unconditional,Exact}
226+
NotYetTreatedParallel{Unconditional,Exact,Tuple{Int64},Tuple{Int64}}
225227
226228
julia> notyettreated([-1, 5, 6], 5)
227229
Parallel trends with any not-yet-treated group:
228-
Not-yet-treated groups: [-1, 5, 6]
229-
Treated since: [5]
230+
Not-yet-treated groups: (-1, 5, 6)
231+
Treated since: (5,)
230232
231233
julia> notyettreated([4, 5, 6], [4, 5, 6])
232234
Parallel trends with any not-yet-treated group:
233-
Not-yet-treated groups: [4, 5, 6]
234-
Treated since: [4, 5, 6]
235+
Not-yet-treated groups: (4, 5, 6)
236+
Treated since: (4, 5, 6)
235237
```
236238
"""
237-
notyettreated(itr, emin, c::ParallelCondition, s::ParallelStrength) =
238-
NotYetTreatedParallel([itr...], emin isa Nothing ? emin : [emin...], c, s)
239-
notyettreated(itr, emin=nothing;
239+
notyettreated(e, ecut, c::ParallelCondition, s::ParallelStrength) =
240+
NotYetTreatedParallel(e, ecut, c, s)
241+
notyettreated(e, ecut=minimum(e);
240242
c::ParallelCondition=Unconditional(), s::ParallelStrength=Exact()) =
241-
NotYetTreatedParallel([itr...], emin isa Nothing ? emin : [emin...], c, s)
243+
NotYetTreatedParallel(e, ecut, c, s)
242244

243245
"""
244246
notyettreated(ts::AbstractTerm...)

src/procedures.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function checkdata(data, subset::Union{AbstractVector, Nothing},
1919
end
2020

2121
if weights !== nothing
22-
colweights = getcolumn(columns(data), weights)
22+
colweights = getcolumn(data, weights)
2323
esample .&= .!ismissing.(colweights) .& (colweights .> 0)
2424
end
2525

@@ -56,23 +56,23 @@ function overlap!(esample::BitArray, tr_rows::BitArray, tr::DynamicTreatment,
5656
overlap_time, _c, _t = _overlaptime(tr, tr_rows, data)
5757
timetype = eltype(overlap_time)
5858
if timetype <: Integer
59-
emin = pr.emin === nothing ? minimum(pr.e) : pr.emin[1]
60-
valid_cohort = filter(x -> x < emin || x in pr.e, overlap_time)
61-
filter!(x -> x < emin, overlap_time)
59+
ecut = pr.ecut === nothing ? minimum(pr.e) : pr.ecut[1]
60+
valid_cohort = filter(x -> x < ecut || x in pr.e, overlap_time)
61+
filter!(x -> x < ecut, overlap_time)
6262
esample .&= (getcolumn(data, tr.time).∈(overlap_time,)) .&
6363
(getcolumn(data, treatname).∈(valid_cohort,))
6464
end
6565
tr_rows .&= esample
6666
end
6767

6868
"""
69-
checkvars(args...)
69+
checkvars!(args...)
7070
71-
Return rows with observations that are nonmissing and satisfy the overlap condition
72-
and rows for observations from treated units.
71+
Exclude rows with missing data or violate the overlap condition
72+
and find rows with data from treated units.
7373
See also [`CheckVars`](@ref).
7474
"""
75-
function checkvars(data, tr::AbstractTreatment, pr::AbstractParallel,
75+
function checkvars!(data, tr::AbstractTreatment, pr::AbstractParallel,
7676
yterm::AbstractTerm, treatname::Symbol, treatintterms::TupleTerm,
7777
xterms::TupleTerm, esample::BitArray)
7878

@@ -87,7 +87,7 @@ function checkvars(data, tr::AbstractTreatment, pr::AbstractParallel,
8787
tr_rows = falses(length(esample))
8888
@inbounds for i in eachindex(esample)
8989
if esample[i]
90-
if treated(pr, getcolumn(data, treatname)[i])
90+
if istreated(pr, getcolumn(data, treatname)[i])
9191
esample[i] = all(v->!ismissing(getcolumn(data, v)[i]), treatedvars)
9292
esample[i] && (tr_rows[i] = true)
9393
else
@@ -104,9 +104,9 @@ end
104104
"""
105105
CheckVars <: StatsStep
106106
107-
Call [`DiffinDiffsBase.checkvars`](@ref) for obtaining valid rows from relevant columns.
107+
Call [`DiffinDiffsBase.checkvars!`](@ref) to exclude invalid rows for relevant variables.
108108
"""
109-
const CheckVars = StatsStep{:CheckVars, typeof(checkvars)}
109+
const CheckVars = StatsStep{:CheckVars, typeof(checkvars!)}
110110

111111
namedargs(::CheckVars) = (data=nothing, tr=nothing, pr=nothing,
112112
yterm=nothing, treatname=nothing, treatintterms=(), xterms=(), esample=nothing)

0 commit comments

Comments
 (0)