Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 11ff4aa

Browse files
authored
Add functions _treatnames and treatnames (#8)
1 parent a4d3f4a commit 11ff4aa

File tree

4 files changed

+45
-9
lines changed

4 files changed

+45
-9
lines changed

src/DiffinDiffsBase.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using Reexport
88
using SplitApplyCombine: groupfind, groupview
99
using StatsBase: Weights, uweights
1010
@reexport using StatsModels
11-
using Tables: istable, getcolumn, columntable
11+
using Tables: istable, getcolumn, columntable, columnnames
1212

1313
import Base: ==, show, union
1414
import Base: eltype, firstindex, lastindex, getindex, iterate, length, sym_in
@@ -68,7 +68,8 @@ export cb,
6868
didspec,
6969
@did,
7070
DIDResult,
71-
outcomename
71+
outcomename,
72+
treatnames
7273

7374
include("utils.jl")
7475
include("treatments.jl")

src/did.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,27 @@ All concrete subtypes of `DIDResult` are expected to have the following fields:
174174
"""
175175
abstract type DIDResult <: StatisticalModel end
176176

177+
"""
178+
_treatnames(treatinds)
179+
180+
Generate names for treatment coefficients.
181+
Assume `treatinds` is compatible with the `Tables.jl` interface.
182+
"""
183+
function _treatnames(treatinds)
184+
cols = columnnames(treatinds)
185+
ncol = length(cols)
186+
# Assume treatinds has at least one column
187+
c1 = cols[1]
188+
names = Ref(string(c1, ": ")).*string.(getcolumn(treatinds, c1))
189+
if ncol > 1
190+
for i in 2:ncol
191+
ci = cols[i]
192+
names .*= Ref(string(" & ", ci, ": ")).*string.(getcolumn(treatinds, ci))
193+
end
194+
end
195+
return names
196+
end
197+
177198
"""
178199
coef(r::DIDResult)
179200
@@ -297,6 +318,13 @@ Return a vector of coefficient names.
297318
"""
298319
coefnames(r::DIDResult) = r.coefnames
299320

321+
"""
322+
treatnames(r::DIDResult)
323+
324+
Return a vector of names for treatment coefficients.
325+
"""
326+
treatnames(r::DIDResult) = r.coefnames[1:size(r.treatinds,1)]
327+
300328
"""
301329
weights(r::DIDResult)
302330

test/did.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,15 @@ end
153153
@test sp6 === @did [noproceed] TestDID @formula(y ~ treat(g, ttreat(t, 0), tpara(0)) & z + x)
154154
end
155155

156+
@testset "_treatnames" begin
157+
t = Table((rel=[1, 2],))
158+
@test _treatnames(t) == ["rel: 1", "rel: 2"]
159+
r = TestResult(2, 2)
160+
@test _treatnames(r.treatinds) == ["rel: $a & c: $b" for a in 1:2 for b in 1:2]
161+
end
162+
156163
@testset "DIDResult" begin
157-
r = result(TestDID, NamedTuple()).result
164+
r = TestResult(2, 2)
158165

159166
@test coef(r) == r.coef
160167
@test coef(r, 1) == 1
@@ -185,6 +192,7 @@ end
185192
@test responsename(r) == "y"
186193
@test outcomename(r) == responsename(r)
187194
@test coefnames(r) == r.coefnames
195+
@test treatnames(r) == r.coefnames[1:4]
188196
@test weights(r) == :w
189197
end
190198

test/runtests.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using DiffinDiffsBase
44
using DataFrames
55
using DiffinDiffsBase: @fieldequal, unpack, @unpack, hastreat, parse_treat,
66
hasintercept, omitsintercept, isintercept, isomitsintercept, parse_intercept,
7-
_f, groupargs, pool, checkdata, checkvars!, makeweights, _getsubcolumns, parse_didargs
7+
_f, groupargs, pool, checkdata, checkvars!, makeweights, _getsubcolumns, parse_didargs,
8+
_treatnames
89
using StatsBase: Weights, UnitWeights
910
using StatsModels: termvars
1011
using TypedTables: Table
@@ -26,9 +27,7 @@ const tests = [
2627

2728
printstyled("Running tests:\n", color=:blue, bold=true)
2829

29-
for test in tests
30-
@time begin
31-
include("$test.jl")
32-
println("\033[1m\033[32mPASSED\033[0m: $(test)")
33-
end
30+
@time for test in tests
31+
include("$test.jl")
32+
println("\033[1m\033[32mPASSED\033[0m: $(test)")
3433
end

0 commit comments

Comments
 (0)