Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 4c2ec2c

Browse files
authored
Add ContrastResult, contrast and post! (#11)
1 parent 03178c1 commit 4c2ec2c

File tree

4 files changed

+300
-21
lines changed

4 files changed

+300
-21
lines changed

src/InteractionWeightedDIDs.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ using DiffinDiffsBase: ValidTimeType, termvars, isintercept, parse_intercept!,
1919

2020
import Base: show
2121
import DiffinDiffsBase: required, default, transformed, combinedargs, copyargs,
22-
valid_didargs, result, vce, nobs, outcomename, weights, treatnames, dof_residual, agg
22+
valid_didargs, result, vce, treatment, nobs, outcomename, weights, treatnames,
23+
dof_residual, agg, post!, _parse_subset
2324
import FixedEffectModels: has_fe
2425

2526
export Vcov,
@@ -40,7 +41,11 @@ export CheckVcov,
4041
RegressionBasedDID,
4142
Reg,
4243
RegressionBasedDIDResult,
43-
AggregatedRegBasedDIDResult
44+
has_fe,
45+
AggregatedRegDIDResult,
46+
has_lsweights,
47+
ContrastResult,
48+
contrast
4449

4550
include("utils.jl")
4651
include("procedures.jl")

src/did.jl

Lines changed: 190 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
Estimation procedure for regression-based difference-in-differences.
55
"""
66
const RegressionBasedDID = DiffinDiffsEstimator{:RegressionBasedDID,
7-
Tuple{CheckData, GroupTreatintterms, GroupXterms, CheckVcov, CheckVars, GroupSample,
7+
Tuple{CheckData, GroupTreatintterms, GroupXterms, GroupContrasts,
8+
CheckVcov, CheckVars, GroupSample,
89
ParseFEterms, GroupFEterms, MakeFEs, CheckFEs, MakeWeights, MakeFESolver,
910
MakeYXCols, MakeTreatCols, SolveLeastSquares, EstVcov, SolveLeastSquaresWeights}}
1011

@@ -26,9 +27,9 @@ function valid_didargs(d::Type{Reg}, ::DynamicTreatment{SharpDesign},
2627
vce=get(args, :vce, Vcov.RobustCovariance())::Vcov.CovarianceEstimator,
2728
treatintterms=treatintterms::TermSet,
2829
xterms=xterms::TermSet,
30+
contrasts=get(args, :contrasts, nothing)::Union{Dict{Symbol,Any},Nothing},
2931
drop_singletons=get(args, :drop_singletons, true)::Bool,
3032
nfethreads=get(args, :nfethreads, Threads.nthreads())::Int,
31-
contrasts=get(args, :contrasts, nothing)::Union{Dict{Symbol,Any},Nothing},
3233
fetol=get(args, :fetol, 1e-8)::Float64,
3334
femaxiter=get(args, :femaxiter, 10000)::Int,
3435
cohortinteracted=get(args, :cohortinteracted, true)::Bool,
@@ -38,11 +39,11 @@ function valid_didargs(d::Type{Reg}, ::DynamicTreatment{SharpDesign},
3839
end
3940

4041
"""
41-
RegressionBasedDIDResult{TR<:AbstractTreatment, CohortInteracted} <: DIDResult
42+
RegressionBasedDIDResult{TR,CohortInteracted,Haslsweights} <: DIDResult{TR}
4243
4344
Estimation results from regression-based difference-in-differences.
4445
"""
45-
struct RegressionBasedDIDResult{TR<:AbstractTreatment, CohortInteracted} <: DIDResult
46+
struct RegressionBasedDIDResult{TR,CohortInteracted,Haslsweights} <: DIDResult{TR}
4647
coef::Vector{Float64}
4748
vcov::Matrix{Float64}
4849
vce::CovarianceEstimator
@@ -81,7 +82,8 @@ function result(::Type{Reg}, @nospecialize(nt::NamedTuple))
8182
cnames = _treatnames(nt.treatcells)
8283
cnames = append!(cnames, coefnames.(nt.xterms))[nt.basiscols]
8384
coefinds = Dict(cnames .=> 1:length(cnames))
84-
didresult = RegressionBasedDIDResult{typeof(nt.tr), nt.cohortinteracted}(
85+
didresult = RegressionBasedDIDResult{typeof(nt.tr),
86+
nt.cohortinteracted, nt.lsweights!==nothing}(
8587
nt.coef, nt.vcov_mat, nt.vce, nt.tr, nt.pr, nt.cellweights, nt.cellcounts,
8688
nt.esample, sum(nt.esample), nt.dof_resid, nt.F, nt.p,
8789
yname, cnames, coefinds, nt.treatcells, nt.treatname, nt.yxterms,
@@ -164,12 +166,12 @@ function show(io::IO, ::MIME"text/plain", r::RegressionBasedDIDResult;
164166
end
165167

166168
"""
167-
AggregatedRegBasedDIDResult{P<:RegressionBasedDIDResult, I} <: AggregatedDIDResult{P}
169+
AggregatedRegDIDResult{TR,Haslsweights,P<:RegressionBasedDIDResult,I} <: AggregatedDIDResult{TR,P}
168170
169171
Estimation results aggregated from a [`RegressionBasedDIDResult`](@ref).
170172
See also [`agg`](@ref).
171173
"""
172-
struct AggregatedRegBasedDIDResult{P<:RegressionBasedDIDResult, I} <: AggregatedDIDResult{P}
174+
struct AggregatedRegDIDResult{TR,Haslsweights,P<:RegressionBasedDIDResult,I} <: AggregatedDIDResult{TR,P}
173175
parent::P
174176
inds::I
175177
coef::Vector{Float64}
@@ -181,6 +183,9 @@ struct AggregatedRegBasedDIDResult{P<:RegressionBasedDIDResult, I} <: Aggregated
181183
coefinds::Dict{String, Int}
182184
treatcells::VecColumnTable
183185
lsweights::Union{TableIndexedMatrix{Float64, Matrix{Float64}, VecColumnTable, VecColumnTable}, Nothing}
186+
ycellmeans::Union{Vector{Float64}, Nothing}
187+
ycellweights::Union{Vector{Float64}, Nothing}
188+
ycellcounts::Union{Vector{Int}, Nothing}
184189
end
185190

186191
function agg(r::RegressionBasedDIDResult{<:DynamicTreatment}, names=nothing;
@@ -226,13 +231,183 @@ function agg(r::RegressionBasedDIDResult{<:DynamicTreatment}, names=nothing;
226231
lswtmat = view(r.lsweights.m, :, inds) * cweights
227232
lswt = TableIndexedMatrix(lswtmat, r.lsweights.r, tcells)
228233
end
229-
return AggregatedRegBasedDIDResult{typeof(r), typeof(inds)}(r, inds, cf, v, cweights,
230-
cellweights, cellcounts, cnames, coefinds, tcells, lswt)
234+
return AggregatedRegDIDResult{typeof(r.tr), lswt!==nothing, typeof(r), typeof(inds)}(
235+
r, inds, cf, v, cweights, cellweights, cellcounts, cnames, coefinds, tcells,
236+
lswt, r.ycellmeans, r.ycellweights, r.ycellcounts)
237+
end
238+
239+
vce(r::AggregatedRegDIDResult) = vce(parent(r))
240+
treatment(r::AggregatedRegDIDResult) = treatment(parent(r))
241+
nobs(r::AggregatedRegDIDResult) = nobs(parent(r))
242+
outcomename(r::AggregatedRegDIDResult) = outcomename(parent(r))
243+
weights(r::AggregatedRegDIDResult) = weights(parent(r))
244+
treatnames(r::AggregatedRegDIDResult) = coefnames(r)
245+
dof_residual(r::AggregatedRegDIDResult) = dof_residual(parent(r))
246+
247+
"""
248+
RegDIDResultOrAgg{TR,Haslsweights}
249+
250+
Union type of [`RegressionBasedDIDResult`](@ref) and [`AggregatedRegDIDResult`](@ref).
251+
"""
252+
const RegDIDResultOrAgg{TR,Haslsweights} =
253+
Union{RegressionBasedDIDResult{TR,<:Any,Haslsweights},
254+
AggregatedRegDIDResult{TR,Haslsweights}}
255+
256+
"""
257+
has_lsweights(r::RegDIDResultOrAgg)
258+
259+
Test whether `r` contains computed least-squares weights (`r.lsweights!==nothing`).
260+
"""
261+
has_lsweights(::RegDIDResultOrAgg{TR,H}) where {TR,H} = H
262+
263+
"""
264+
ContrastResult{T,M,R,C} <: AbstractMatrix{T}
265+
266+
Matrix type that holds least-squares weights obtained from one or more
267+
[`RegDIDResultOrAgg`](@ref)s computed over the same set of cells
268+
and cell-level averages.
269+
See also [`contrast`](@ref).
270+
271+
The least-squares weights are stored in a `Matrix` that can be retrieved
272+
with property name `:m`,
273+
where the weights for each treatment coefficient
274+
are stored columnwise starting from the second column and
275+
the first column contains the cell-level averages.
276+
The indices for cells can be accessed with property name `:r`;
277+
and indices for identifying the coefficients can be accessed with property name `:c`.
278+
The [`RegDIDResultOrAgg`](@ref)s used to generate the `ContrastResult`
279+
can be accessed by calling `parent`.
280+
"""
281+
struct ContrastResult{T,M,R,C} <: AbstractMatrix{T}
282+
rs::Vector{RegDIDResultOrAgg}
283+
m::TableIndexedMatrix{T,M,R,C}
284+
function ContrastResult(rs::Vector{RegDIDResultOrAgg},
285+
m::TableIndexedMatrix{T,M,R,C}) where {T,M,R,C}
286+
cnames = columnnames(m.c)
287+
cnames[1]==:iresult && cnames[2]==:icoef && cnames[3]==:name || throw(ArgumentError(
288+
"Table paired with column indices has unaccepted column names"))
289+
return new{T,M,R,C}(rs, m)
290+
end
291+
end
292+
293+
_getmat(cr::ContrastResult) = getfield(cr, :m)
294+
Base.size(cr::ContrastResult) = size(_getmat(cr))
295+
Base.getindex(cr::ContrastResult, i) = _getmat(cr)[i]
296+
Base.getindex(cr::ContrastResult, i, j) = _getmat(cr)[i,j]
297+
Base.IndexStyle(::Type{<:ContrastResult{T,M}}) where {T,M} = IndexStyle(M)
298+
Base.getproperty(cr::ContrastResult, n::Symbol) = getproperty(_getmat(cr), n)
299+
Base.parent(cr::ContrastResult) = getfield(cr, :rs)
300+
301+
"""
302+
contrast(r1::RegDIDResultOrAgg, rs::RegDIDResultOrAgg...)
303+
304+
Construct a [`ContrastResult`](@ref) by collecting the computed least-squares weights
305+
from each of the [`RegDIDResultOrAgg`](@ref).
306+
"""
307+
function contrast(r1::RegDIDResultOrAgg, rs::RegDIDResultOrAgg...)
308+
has_lsweights(r1) && all(r->has_lsweights(r), rs) || throw(ArgumentError(
309+
"Results must contain computed least-sqaure weights"))
310+
ri = r1.lsweights.r
311+
ncoef = ntreatcoef(r1)
312+
m = r1.lsweights.m
313+
for r in rs
314+
r.lsweights.r == ri || throw(ArgumentError(
315+
"Cells for least-square weights comparisons must be identical across the inputs"))
316+
ncoef += ntreatcoef(r)
317+
end
318+
rs = RegDIDResultOrAgg[r1, rs...]
319+
m = hcat(r1.ycellmeans, (r.lsweights.m for r in rs)...)
320+
rinds = vcat(0, (fill(i+1, ntreatcoef(r)) for (i, r) in enumerate(rs))...)
321+
cinds = vcat(0, (1:ntreatcoef(r) for r in rs)...)
322+
names = vcat("cellmeans", (treatnames(r) for r in rs)...)
323+
ci = VecColumnTable((iresult=rinds, icoef=cinds, name=names))
324+
return ContrastResult(rs, TableIndexedMatrix(m, ri, ci))
325+
end
326+
327+
function Base.:(==)(x::ContrastResult, y::ContrastResult)
328+
# Assume no missing
329+
x.m == y.m || return false
330+
x.r == y.r || return false
331+
x.c == y.c || return false
332+
return parent(x) == parent(y)
333+
end
334+
335+
function Base.sort!(cr::ContrastResult; @nospecialize(kwargs...))
336+
p = sortperm(cr.r; kwargs...)
337+
@inbounds for col in cr.r
338+
col .= col[p]
339+
end
340+
@inbounds cr.m .= cr.m[p,:]
341+
return cr
342+
end
343+
344+
_parse_subset(cr::ContrastResult, by::Pair) = (inds = apply(cr.r, by); return inds)
345+
346+
function _parse_subset(cr::ContrastResult, inds)
347+
eltype(inds) <: Pair || return inds
348+
inds = apply_and(cr.r, inds...)
349+
return inds
231350
end
232351

233-
vce(r::AggregatedRegBasedDIDResult) = vce(parent(r))
234-
nobs(r::AggregatedRegBasedDIDResult) = nobs(parent(r))
235-
outcomename(r::AggregatedRegBasedDIDResult) = outcomename(parent(r))
236-
weights(r::AggregatedRegBasedDIDResult) = weights(parent(r))
237-
treatnames(r::AggregatedRegBasedDIDResult) = coefnames(r)
238-
dof_residual(r::AggregatedRegBasedDIDResult) = dof_residual(parent(r))
352+
_parse_subset(::ContrastResult, ::Colon) = Colon()
353+
354+
function Base.view(cr::ContrastResult, subset)
355+
inds = _parse_subset(cr, subset)
356+
r = view(cr.r, inds)
357+
m = view(cr.m, inds, :)
358+
return ContrastResult(parent(cr), TableIndexedMatrix(m, r, cr.c))
359+
end
360+
361+
function _checklengthmatch(v, name::String, N::Int)
362+
length(v) == N || throw(ArgumentError(
363+
"The length of $name ($(length(v))) does not match the number of rows of cr ($(N))"))
364+
end
365+
366+
_checklengthmatch(v::Nothing, name::String, N::Int) = nothing
367+
368+
"""
369+
post!(gl, gr, gd, ::StataPostHDF, cr::ContrastResult, left::Int=2, right::Int=3; kwargs...)
370+
371+
Export the least-squares weights for coefficients indexed by `left` and `right`
372+
from `cr` for Stata module [`posthdf`](https://github.com/junyuan-chen/posthdf).
373+
The contributions of the cells to the difference between the two coefficients
374+
are computed and also exported.
375+
The weights and contributions are stored as coefficient estimates
376+
in three groups `gl`, `gr` and `gd` respectively.
377+
The groups can be `HDF5.Group`s or objects that can be indexed by strings.
378+
379+
# Keywords
380+
- `lefttag::String=string(left)`: name to be used as `depvar` in Stata after being prefixed by `"l_"` for the coefficient indexed by `left`.
381+
- `righttag::String=string(right)`: name to be used as `depvar` in Stata after being prefixed by `"r_"` for the coefficient indexed by `right`.
382+
- `model::String="InteractionWeightedDIDs.ContrastResult"`: name of the model.
383+
- `eqnames::Union{AbstractVector, Nothing}=nothing`: equation names prefixed to coefficient names in Stata.
384+
- `colnames::Union{AbstractVector, Nothing}=nothing`: column names used as coefficient names in Stata.
385+
- `at::Union{AbstractVector{<:Real}, Nothing}=nothing`: the `at` vector in Stata.
386+
"""
387+
function post!(gl, gr, gd, ::StataPostHDF, cr::ContrastResult, left::Int=2, right::Int=3;
388+
lefttag::String=string(left), righttag::String=string(right),
389+
model::String="InteractionWeightedDIDs.ContrastResult",
390+
eqnames::Union{AbstractVector, Nothing}=nothing,
391+
colnames::Union{AbstractVector, Nothing}=nothing,
392+
at::Union{AbstractVector{<:Real}, Nothing}=nothing)
393+
N = size(cr.m, 1)
394+
_checklengthmatch(eqnames, "eqnames", N)
395+
_checklengthmatch(colnames, "colnames", N)
396+
_checklengthmatch(at, "at", N)
397+
gl["depvar"] = string("l_", lefttag)
398+
wtl = view(cr.m, :, left)[:]
399+
gl["b"] = wtl
400+
gr["depvar"] = string("r_", righttag)
401+
wtr = view(cr.m, :, right)[:]
402+
gr["b"] = wtr
403+
gd["depvar"] = string("d_", lefttag, "_", righttag)
404+
diff = (wtl.-wtr).*view(cr.m,:,1)[:]
405+
gd["b"] = diff
406+
colnames === nothing && (colnames = 1:N)
407+
cnames = eqnames === nothing ? string.(colnames) : string.(eqnames, ":", colnames)
408+
for g in (gl, gr, gd)
409+
g["model"] = model
410+
g["coefnames"] = cnames
411+
at === nothing || (g["at"] = at)
412+
end
413+
end

src/procedures.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ required(::EstVcov) = (:data, :esample, :vce, :coef, :X, :crossx, :residuals, :x
544544
"""
545545
solveleastsquaresweights(args...)
546546
547-
Solve the cell-level weights assigned by least-sqaures.
547+
Solve the cell-level weights assigned by least squares.
548548
See also [`SolveLeastSquaresWeights`](@ref).
549549
"""
550550
function solveleastsquaresweights(::DynamicTreatment{SharpDesign},
@@ -616,7 +616,7 @@ end
616616
SolveLeastSquaresWeights <: StatsStep
617617
618618
Call [`InteractionWeightedDIDs.solveleastsquaresweights`](@ref)
619-
to solve the cell-level weights assigned by least-sqaures.
619+
to solve the cell-level weights assigned by least squares.
620620
"""
621621
const SolveLeastSquaresWeights = StatsStep{:SolveLeastSquaresWeights,
622622
typeof(solveleastsquaresweights), true}

0 commit comments

Comments
 (0)