@@ -5,7 +5,7 @@ Estimation procedure for regression-based difference-in-differences.
55"""
66const RegressionBasedDID = DiffinDiffsEstimator{:RegressionBasedDID ,
77 Tuple{CheckData, GroupTerms, CheckVcov, CheckVars, CheckFEs, MakeWeights, MakeFESolver,
8- MakeYXCols, MakeTreatCols, SolveLeastSquares, EstVcov}}
8+ MakeYXCols, MakeTreatCols, SolveLeastSquares, EstVcov, SolveLeastSquaresWeights }}
99
1010const Reg = RegressionBasedDID
1111
@@ -14,6 +14,7 @@ function valid_didargs(d::Type{Reg}, ::DynamicTreatment{SharpDesign},
1414 name = get (args, :name , " " ):: String
1515 treatintterms = haskey (args, :treatintterms ) ? args[:treatintterms ] : TermSet ()
1616 xterms = haskey (args, :xterms ) ? args[:xterms ] : TermSet ()
17+ solvelsweights = haskey (args, :lswtnames ) || get (args, :solvelsweights , false ):: Bool
1718 ntargs = (data= args[:data ],
1819 tr= args[:tr ]:: DynamicTreatment{SharpDesign} ,
1920 pr= args[:pr ]:: TrendParallel{Unconditional, Exact} ,
@@ -29,15 +30,22 @@ function valid_didargs(d::Type{Reg}, ::DynamicTreatment{SharpDesign},
2930 contrasts= get (args, :contrasts , nothing ):: Union{Dict{Symbol,Any},Nothing} ,
3031 fetol= get (args, :fetol , 1e-8 ):: Float64 ,
3132 femaxiter= get (args, :femaxiter , 10000 ):: Int ,
32- cohortinteracted= get (args, :cohortinteracted , true ):: Bool )
33+ cohortinteracted= get (args, :cohortinteracted , true ):: Bool ,
34+ solvelsweights= solvelsweights:: Bool ,
35+ lswtnames= get (args, :lswtnames , ()))
3336 return name, d, ntargs
3437end
3538
36- struct RegressionBasedDIDResult{CohortInteracted} <: DIDResult
39+ """
40+ RegressionBasedDIDResult{TR<:AbstractTreatment, CohortInteracted} <: DIDResult
41+
42+ Estimation results from regression-based difference-in-differences.
43+ """
44+ struct RegressionBasedDIDResult{TR<: AbstractTreatment , CohortInteracted} <: DIDResult
3745 coef:: Vector{Float64}
3846 vcov:: Matrix{Float64}
3947 vce:: CovarianceEstimator
40- tr:: AbstractTreatment
48+ tr:: TR
4149 pr:: AbstractParallel
4250 cellweights:: Vector{Float64}
4351 cellcounts:: Vector{Int}
@@ -60,6 +68,10 @@ struct RegressionBasedDIDResult{CohortInteracted} <: DIDResult
6068 nfeiterations:: Union{Int, Nothing}
6169 feconverged:: Union{Bool, Nothing}
6270 nfesingledropped:: Int
71+ lsweights:: Union{TableIndexedMatrix{Float64, Matrix{Float64}, VecColumnTable, VecColumnTable}, Nothing}
72+ ycellmeans:: Union{Vector{Float64}, Nothing}
73+ ycellweights:: Union{Vector{Float64}, Nothing}
74+ ycellcounts:: Union{Vector{Int}, Nothing}
6375end
6476
6577function result (:: Type{Reg} , @nospecialize (nt:: NamedTuple ))
@@ -68,12 +80,13 @@ function result(::Type{Reg}, @nospecialize(nt::NamedTuple))
6880 cnames = _treatnames (nt. treatcells)
6981 cnames = append! (cnames, coefnames .(nt. xterms))[nt. basecols]
7082 coefinds = Dict (cnames .=> 1 : length (cnames))
71- didresult = RegressionBasedDIDResult {nt.cohortinteracted} (
83+ didresult = RegressionBasedDIDResult {typeof(nt.tr), nt.cohortinteracted} (
7284 nt. coef, nt. vcov_mat, nt. vce, nt. tr, nt. pr, nt. cellweights, nt. cellcounts,
7385 nt. esample, sum (nt. esample), nt. dof_resid, nt. F, nt. p,
7486 yname, cnames, coefinds, nt. treatcells, nt. treatname, nt. yxterms,
7587 yterm, nt. xterms, nt. contrasts, nt. weightname,
76- nt. fenames, nt. nfeiterations, nt. feconverged, nt. nsingle)
88+ nt. fenames, nt. nfeiterations, nt. feconverged, nt. nsingle,
89+ nt. lsweights, nt. ycellmeans, nt. ycellweights, nt. ycellcounts)
7790 return merge (nt, (result= didresult,))
7891end
7992
@@ -96,22 +109,22 @@ _nunique(t, s::Symbol) = length(unique(getproperty(t, s)))
96109_excluded_rel_str (tr:: DynamicTreatment ) =
97110 isempty (tr. exc) ? " none" : join (string .(tr. exc), " " )
98111
99- _treat_info (r:: RegressionBasedDIDResult{true} , tr :: DynamicTreatment ) = (
112+ _treat_info (r:: RegressionBasedDIDResult{<: DynamicTreatment, true} ) = (
100113 " Number of cohorts" => _nunique (r. treatcells, r. treatname),
101114 " Interactions within cohorts" => length (columnnames (r. treatcells)) - 2 ,
102115 " Relative time periods" => _nunique (r. treatcells, :rel ),
103- " Excluded periods" => NoQuote (_excluded_rel_str (tr))
116+ " Excluded periods" => NoQuote (_excluded_rel_str (r . tr))
104117)
105118
106- _treat_info (r:: RegressionBasedDIDResult{false} , tr :: DynamicTreatment ) = (
119+ _treat_info (r:: RegressionBasedDIDResult{<: DynamicTreatment, false} ) = (
107120 " Relative time periods" => _nunique (r. treatcells, :rel ),
108- " Excluded periods" => NoQuote (_excluded_rel_str (tr))
121+ " Excluded periods" => NoQuote (_excluded_rel_str (r . tr))
109122)
110123
111- _treat_spec (r:: RegressionBasedDIDResult{true} , tr :: DynamicTreatment{SharpDesign} ) =
124+ _treat_spec (r:: RegressionBasedDIDResult{DynamicTreatment{SharpDesign}, true } ) =
112125 " Cohort-interacted sharp dynamic specification"
113126
114- _treat_spec (r:: RegressionBasedDIDResult{false} , tr :: DynamicTreatment{SharpDesign} ) =
127+ _treat_spec (r:: RegressionBasedDIDResult{DynamicTreatment{SharpDesign}, false } ) =
115128 " Sharp dynamic specification"
116129
117130_fe_info (r:: RegressionBasedDIDResult ) = (
@@ -126,12 +139,12 @@ function show(io::IO, ::MIME"text/plain", r::RegressionBasedDIDResult;
126139 halfwidth = div (totalwidth- interwidth, 2 )
127140 top_info = _top_info (r)
128141 fe_info = has_fe (r) ? _fe_info (r) : ()
129- tr_info = _treat_info (r, r . tr )
142+ tr_info = _treat_info (r)
130143 blocks = (top_info, tr_info, fe_info)
131144 fes = has_fe (r) ? join (string .(r. fenames), " " ) : " none"
132145 fetitle = string (" Fixed effects: " , fes)
133146 blocktitles = (" Summary of results: Regression-based DID" ,
134- _treat_spec (r, r . tr ), fetitle[1 : min (totalwidth,length (fetitle))])
147+ _treat_spec (r), fetitle[1 : min (totalwidth,length (fetitle))])
135148
136149 for (ib, b) in enumerate (blocks)
137150 println (io, repeat (' ─' , totalwidth))
@@ -148,3 +161,77 @@ function show(io::IO, ::MIME"text/plain", r::RegressionBasedDIDResult;
148161 end
149162 print (io, repeat (' ─' , totalwidth))
150163end
164+
165+ """
166+ AggregatedRegBasedDIDResult{P<:RegressionBasedDIDResult, I} <: AggregatedDIDResult{P}
167+
168+ Estimation results aggregated from a [`RegressionBasedDIDResult`](@ref).
169+ See also [`agg`](@ref).
170+ """
171+ struct AggregatedRegBasedDIDResult{P<: RegressionBasedDIDResult , I} <: AggregatedDIDResult{P}
172+ parent:: P
173+ inds:: I
174+ coef:: Vector{Float64}
175+ vcov:: Matrix{Float64}
176+ coefweights:: Matrix{Float64}
177+ cellweights:: Vector{Float64}
178+ cellcounts:: Vector{Int}
179+ coefnames:: Vector{String}
180+ coefinds:: Dict{String, Int}
181+ treatcells:: VecColumnTable
182+ lsweights:: Union{TableIndexedMatrix{Float64, Matrix{Float64}, VecColumnTable, VecColumnTable}, Nothing}
183+ end
184+
185+ function agg (r:: RegressionBasedDIDResult{<:DynamicTreatment} , names= nothing ;
186+ bys= nothing , subset= nothing )
187+ inds = subset === nothing ? Colon () : _parse_subset (r, subset, false )
188+ ptcells = treatcells (r)
189+ bycells = view (ptcells, inds)
190+ _parse_bycells! (getfield (bycells, :columns ), ptcells, bys)
191+ names === nothing || (bycells = subcolumns (bycells, names, nomissing= false ))
192+
193+ tcells, rows = cellrows (bycells, findcell (bycells))
194+ ncell = length (rows)
195+ pcf = view (treatcoef (r), inds)
196+ cweights = zeros (length (pcf), ncell)
197+ pcellweights = view (r. cellweights, inds)
198+ pcellcounts = view (r. cellcounts, inds)
199+ # Ensure the weights for each relative time always sum up to one
200+ rels = view (r. treatcells. rel, inds)
201+ for (i, rs) in enumerate (rows)
202+ if length (rs) > 1
203+ relgroups = _groupfind (view (rels, rs))
204+ for inds in values (relgroups)
205+ if length (inds) > 1
206+ cwts = view (pcellweights, view (rs, inds))
207+ cweights[view (rs, inds), i] .= cwts ./ sum (cwts)
208+ else
209+ cweights[rs[inds[1 ]], i] = 1.0
210+ end
211+ end
212+ else
213+ cweights[rs[1 ], i] = 1.0
214+ end
215+ end
216+ cf = cweights' * pcf
217+ v = cweights' * view (treatvcov (r), inds, inds) * cweights
218+ cellweights = [sum (pcellweights[rows[i]]) for i in 1 : ncell]
219+ cellcounts = [sum (pcellcounts[rows[i]]) for i in 1 : ncell]
220+ cnames = _treatnames (tcells)
221+ coefinds = Dict (cnames .=> keys (cnames))
222+ if r. lsweights === nothing
223+ lswt = nothing
224+ else
225+ lswtmat = view (r. lsweights. m, :, inds) * cweights
226+ lswt = TableIndexedMatrix (lswtmat, r. lsweights. r, tcells)
227+ end
228+ return AggregatedRegBasedDIDResult {typeof(r), typeof(inds)} (r, inds, cf, v, cweights,
229+ cellweights, cellcounts, cnames, coefinds, tcells, lswt)
230+ end
231+
232+ vce (r:: AggregatedRegBasedDIDResult ) = vce (parent (r))
233+ nobs (r:: AggregatedRegBasedDIDResult ) = nobs (parent (r))
234+ outcomename (r:: AggregatedRegBasedDIDResult ) = outcomename (parent (r))
235+ weights (r:: AggregatedRegBasedDIDResult ) = weights (parent (r))
236+ treatnames (r:: AggregatedRegBasedDIDResult ) = coefnames (r)
237+ dof_residual (r:: AggregatedRegBasedDIDResult ) = dof_residual (parent (r))
0 commit comments