Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 04ceeb3

Browse files
authored
Add SolveLeastSquaresWeights, AggregatedRegBasedDIDResult and agg (#6)
1 parent c8c0f82 commit 04ceeb3

File tree

7 files changed

+444
-110
lines changed

7 files changed

+444
-110
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1616
Vcov = "ec2bfdc2-55df-4fc9-b9ae-4958c2cf2486"
1717

1818
[compat]
19-
DiffinDiffsBase = "0.2.1"
19+
DataFrames = "0.22"
20+
DiffinDiffsBase = "0.3"
2021
FixedEffectModels = "1.4"
2122
FixedEffects = "2"
2223
Reexport = "0.2, 1"

src/InteractionWeightedDIDs.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ using Tables
1313
using Tables: getcolumn, columnnames
1414
using Vcov
1515
@reexport using DiffinDiffsBase
16-
using DiffinDiffsBase: TimeType, termvars, isintercept, parse_intercept!, _treatnames
16+
using DiffinDiffsBase: TimeType, termvars, isintercept, parse_intercept!,
17+
_count!, _groupfind, _treatnames, _parse_bycells!, _parse_subset
1718

1819
import Base: show
1920
import DiffinDiffsBase: required, default, transformed, combinedargs, copyargs,
20-
valid_didargs, result, _count!
21+
valid_didargs, result, vce, nobs, outcomename, weights, treatnames, dof_residual, agg
2122
import FixedEffectModels: has_fe
2223

2324
export Vcov,
@@ -30,10 +31,12 @@ export CheckVcov,
3031
MakeTreatCols,
3132
SolveLeastSquares,
3233
EstVcov,
34+
SolveLeastSquaresWeights,
3335

3436
RegressionBasedDID,
3537
Reg,
36-
RegressionBasedDIDResult
38+
RegressionBasedDIDResult,
39+
AggregatedRegBasedDIDResult
3740

3841
include("utils.jl")
3942
include("procedures.jl")

src/did.jl

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Estimation procedure for regression-based difference-in-differences.
55
"""
66
const RegressionBasedDID = DiffinDiffsEstimator{:RegressionBasedDID,
77
Tuple{CheckData, GroupTerms, CheckVcov, CheckVars, CheckFEs, MakeWeights, MakeFESolver,
8-
MakeYXCols, MakeTreatCols, SolveLeastSquares, EstVcov}}
8+
MakeYXCols, MakeTreatCols, SolveLeastSquares, EstVcov, SolveLeastSquaresWeights}}
99

1010
const Reg = RegressionBasedDID
1111

@@ -14,6 +14,7 @@ function valid_didargs(d::Type{Reg}, ::DynamicTreatment{SharpDesign},
1414
name = get(args, :name, "")::String
1515
treatintterms = haskey(args, :treatintterms) ? args[:treatintterms] : TermSet()
1616
xterms = haskey(args, :xterms) ? args[:xterms] : TermSet()
17+
solvelsweights = haskey(args, :lswtnames) || get(args, :solvelsweights, false)::Bool
1718
ntargs = (data=args[:data],
1819
tr=args[:tr]::DynamicTreatment{SharpDesign},
1920
pr=args[:pr]::TrendParallel{Unconditional, Exact},
@@ -29,15 +30,22 @@ function valid_didargs(d::Type{Reg}, ::DynamicTreatment{SharpDesign},
2930
contrasts=get(args, :contrasts, nothing)::Union{Dict{Symbol,Any},Nothing},
3031
fetol=get(args, :fetol, 1e-8)::Float64,
3132
femaxiter=get(args, :femaxiter, 10000)::Int,
32-
cohortinteracted=get(args, :cohortinteracted, true)::Bool)
33+
cohortinteracted=get(args, :cohortinteracted, true)::Bool,
34+
solvelsweights=solvelsweights::Bool,
35+
lswtnames=get(args, :lswtnames, ()))
3336
return name, d, ntargs
3437
end
3538

36-
struct RegressionBasedDIDResult{CohortInteracted} <: DIDResult
39+
"""
40+
RegressionBasedDIDResult{TR<:AbstractTreatment, CohortInteracted} <: DIDResult
41+
42+
Estimation results from regression-based difference-in-differences.
43+
"""
44+
struct RegressionBasedDIDResult{TR<:AbstractTreatment, CohortInteracted} <: DIDResult
3745
coef::Vector{Float64}
3846
vcov::Matrix{Float64}
3947
vce::CovarianceEstimator
40-
tr::AbstractTreatment
48+
tr::TR
4149
pr::AbstractParallel
4250
cellweights::Vector{Float64}
4351
cellcounts::Vector{Int}
@@ -60,6 +68,10 @@ struct RegressionBasedDIDResult{CohortInteracted} <: DIDResult
6068
nfeiterations::Union{Int, Nothing}
6169
feconverged::Union{Bool, Nothing}
6270
nfesingledropped::Int
71+
lsweights::Union{TableIndexedMatrix{Float64, Matrix{Float64}, VecColumnTable, VecColumnTable}, Nothing}
72+
ycellmeans::Union{Vector{Float64}, Nothing}
73+
ycellweights::Union{Vector{Float64}, Nothing}
74+
ycellcounts::Union{Vector{Int}, Nothing}
6375
end
6476

6577
function result(::Type{Reg}, @nospecialize(nt::NamedTuple))
@@ -68,12 +80,13 @@ function result(::Type{Reg}, @nospecialize(nt::NamedTuple))
6880
cnames = _treatnames(nt.treatcells)
6981
cnames = append!(cnames, coefnames.(nt.xterms))[nt.basecols]
7082
coefinds = Dict(cnames .=> 1:length(cnames))
71-
didresult = RegressionBasedDIDResult{nt.cohortinteracted}(
83+
didresult = RegressionBasedDIDResult{typeof(nt.tr), nt.cohortinteracted}(
7284
nt.coef, nt.vcov_mat, nt.vce, nt.tr, nt.pr, nt.cellweights, nt.cellcounts,
7385
nt.esample, sum(nt.esample), nt.dof_resid, nt.F, nt.p,
7486
yname, cnames, coefinds, nt.treatcells, nt.treatname, nt.yxterms,
7587
yterm, nt.xterms, nt.contrasts, nt.weightname,
76-
nt.fenames, nt.nfeiterations, nt.feconverged, nt.nsingle)
88+
nt.fenames, nt.nfeiterations, nt.feconverged, nt.nsingle,
89+
nt.lsweights, nt.ycellmeans, nt.ycellweights, nt.ycellcounts)
7790
return merge(nt, (result=didresult,))
7891
end
7992

@@ -96,22 +109,22 @@ _nunique(t, s::Symbol) = length(unique(getproperty(t, s)))
96109
_excluded_rel_str(tr::DynamicTreatment) =
97110
isempty(tr.exc) ? "none" : join(string.(tr.exc), " ")
98111

99-
_treat_info(r::RegressionBasedDIDResult{true}, tr::DynamicTreatment) = (
112+
_treat_info(r::RegressionBasedDIDResult{<:DynamicTreatment, true}) = (
100113
"Number of cohorts" => _nunique(r.treatcells, r.treatname),
101114
"Interactions within cohorts" => length(columnnames(r.treatcells)) - 2,
102115
"Relative time periods" => _nunique(r.treatcells, :rel),
103-
"Excluded periods" => NoQuote(_excluded_rel_str(tr))
116+
"Excluded periods" => NoQuote(_excluded_rel_str(r.tr))
104117
)
105118

106-
_treat_info(r::RegressionBasedDIDResult{false}, tr::DynamicTreatment) = (
119+
_treat_info(r::RegressionBasedDIDResult{<:DynamicTreatment, false}) = (
107120
"Relative time periods" => _nunique(r.treatcells, :rel),
108-
"Excluded periods" => NoQuote(_excluded_rel_str(tr))
121+
"Excluded periods" => NoQuote(_excluded_rel_str(r.tr))
109122
)
110123

111-
_treat_spec(r::RegressionBasedDIDResult{true}, tr::DynamicTreatment{SharpDesign}) =
124+
_treat_spec(r::RegressionBasedDIDResult{DynamicTreatment{SharpDesign}, true}) =
112125
"Cohort-interacted sharp dynamic specification"
113126

114-
_treat_spec(r::RegressionBasedDIDResult{false}, tr::DynamicTreatment{SharpDesign}) =
127+
_treat_spec(r::RegressionBasedDIDResult{DynamicTreatment{SharpDesign}, false}) =
115128
"Sharp dynamic specification"
116129

117130
_fe_info(r::RegressionBasedDIDResult) = (
@@ -126,12 +139,12 @@ function show(io::IO, ::MIME"text/plain", r::RegressionBasedDIDResult;
126139
halfwidth = div(totalwidth-interwidth, 2)
127140
top_info = _top_info(r)
128141
fe_info = has_fe(r) ? _fe_info(r) : ()
129-
tr_info = _treat_info(r, r.tr)
142+
tr_info = _treat_info(r)
130143
blocks = (top_info, tr_info, fe_info)
131144
fes = has_fe(r) ? join(string.(r.fenames), " ") : "none"
132145
fetitle = string("Fixed effects: ", fes)
133146
blocktitles = ("Summary of results: Regression-based DID",
134-
_treat_spec(r, r.tr), fetitle[1:min(totalwidth,length(fetitle))])
147+
_treat_spec(r), fetitle[1:min(totalwidth,length(fetitle))])
135148

136149
for (ib, b) in enumerate(blocks)
137150
println(io, repeat('─', totalwidth))
@@ -148,3 +161,77 @@ function show(io::IO, ::MIME"text/plain", r::RegressionBasedDIDResult;
148161
end
149162
print(io, repeat('─', totalwidth))
150163
end
164+
165+
"""
166+
AggregatedRegBasedDIDResult{P<:RegressionBasedDIDResult, I} <: AggregatedDIDResult{P}
167+
168+
Estimation results aggregated from a [`RegressionBasedDIDResult`](@ref).
169+
See also [`agg`](@ref).
170+
"""
171+
struct AggregatedRegBasedDIDResult{P<:RegressionBasedDIDResult, I} <: AggregatedDIDResult{P}
172+
parent::P
173+
inds::I
174+
coef::Vector{Float64}
175+
vcov::Matrix{Float64}
176+
coefweights::Matrix{Float64}
177+
cellweights::Vector{Float64}
178+
cellcounts::Vector{Int}
179+
coefnames::Vector{String}
180+
coefinds::Dict{String, Int}
181+
treatcells::VecColumnTable
182+
lsweights::Union{TableIndexedMatrix{Float64, Matrix{Float64}, VecColumnTable, VecColumnTable}, Nothing}
183+
end
184+
185+
function agg(r::RegressionBasedDIDResult{<:DynamicTreatment}, names=nothing;
186+
bys=nothing, subset=nothing)
187+
inds = subset === nothing ? Colon() : _parse_subset(r, subset, false)
188+
ptcells = treatcells(r)
189+
bycells = view(ptcells, inds)
190+
_parse_bycells!(getfield(bycells, :columns), ptcells, bys)
191+
names === nothing || (bycells = subcolumns(bycells, names, nomissing=false))
192+
193+
tcells, rows = cellrows(bycells, findcell(bycells))
194+
ncell = length(rows)
195+
pcf = view(treatcoef(r), inds)
196+
cweights = zeros(length(pcf), ncell)
197+
pcellweights = view(r.cellweights, inds)
198+
pcellcounts = view(r.cellcounts, inds)
199+
# Ensure the weights for each relative time always sum up to one
200+
rels = view(r.treatcells.rel, inds)
201+
for (i, rs) in enumerate(rows)
202+
if length(rs) > 1
203+
relgroups = _groupfind(view(rels, rs))
204+
for inds in values(relgroups)
205+
if length(inds) > 1
206+
cwts = view(pcellweights, view(rs, inds))
207+
cweights[view(rs, inds), i] .= cwts ./ sum(cwts)
208+
else
209+
cweights[rs[inds[1]], i] = 1.0
210+
end
211+
end
212+
else
213+
cweights[rs[1], i] = 1.0
214+
end
215+
end
216+
cf = cweights' * pcf
217+
v = cweights' * view(treatvcov(r), inds, inds) * cweights
218+
cellweights = [sum(pcellweights[rows[i]]) for i in 1:ncell]
219+
cellcounts = [sum(pcellcounts[rows[i]]) for i in 1:ncell]
220+
cnames = _treatnames(tcells)
221+
coefinds = Dict(cnames .=> keys(cnames))
222+
if r.lsweights === nothing
223+
lswt = nothing
224+
else
225+
lswtmat = view(r.lsweights.m, :, inds) * cweights
226+
lswt = TableIndexedMatrix(lswtmat, r.lsweights.r, tcells)
227+
end
228+
return AggregatedRegBasedDIDResult{typeof(r), typeof(inds)}(r, inds, cf, v, cweights,
229+
cellweights, cellcounts, cnames, coefinds, tcells, lswt)
230+
end
231+
232+
vce(r::AggregatedRegBasedDIDResult) = vce(parent(r))
233+
nobs(r::AggregatedRegBasedDIDResult) = nobs(parent(r))
234+
outcomename(r::AggregatedRegBasedDIDResult) = outcomename(parent(r))
235+
weights(r::AggregatedRegBasedDIDResult) = weights(parent(r))
236+
treatnames(r::AggregatedRegBasedDIDResult) = coefnames(r)
237+
dof_residual(r::AggregatedRegBasedDIDResult) = dof_residual(parent(r))

0 commit comments

Comments
 (0)