Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit bf568fd

Browse files
authored
Replace Terms with TermSet and allow grouping args by isequal (#14)
1 parent aa25ed2 commit bf568fd

File tree

11 files changed

+191
-99
lines changed

11 files changed

+191
-99
lines changed

src/DiffinDiffsBase.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Tables: istable, getcolumn, columntable, columnnames
1212
import Base: ==, show, union
1313
import Base: eltype, firstindex, lastindex, getindex, iterate, length, sym_in
1414
import StatsBase: coef, vcov, responsename, coefnames, weights, nobs, dof_residual
15-
import StatsModels: termvars, hasintercept, omitsintercept
15+
import StatsModels: termvars
1616

1717
const TimeType = Int
1818

@@ -46,7 +46,7 @@ export cb,
4646
notyettreated,
4747
istreated,
4848

49-
Terms,
49+
TermSet,
5050
eachterm,
5151
TreatmentTerm,
5252
treat,
@@ -60,6 +60,7 @@ export cb,
6060
@specset,
6161

6262
CheckData,
63+
GroupTerms,
6364
CheckVars,
6465
MakeWeights,
6566

src/StatsProcedures.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""
2-
StatsStep{Alias, F<:Function}
2+
StatsStep{Alias, F<:Function, ById}
33
44
Specify the function for moving a step in an [`AbstractStatsProcedure`](@ref).
55
An instance of `StatsStep` is callable.
66
77
# Parameters
88
- `Alias::Symbol`: alias of the type for pretty-printing.
99
- `F<:Function`: type of the function to be called by `StatsStep`.
10+
- `ById::Bool`: whether arguments from multiple [`StatsSpec`](@ref)s should be grouped by object identity (`true`) or by `isequal` (`false`).
1011
1112
# Methods
1213
(step::StatsStep{A,F})(ntargs::NamedTuple; verbose::Bool=false)
@@ -23,9 +24,10 @@ in case both are specified.
2324
## Returns
2425
- `NamedTuple`: named intermediate results.
2526
"""
26-
struct StatsStep{Alias, F<:Function} end
27+
struct StatsStep{Alias, F<:Function, ById} end
2728

2829
_f(::StatsStep{A,F}) where {A,F} = F.instance
30+
_byid(::StatsStep{A,F,I}) where {A,F,I} = I
2931

3032
"""
3133
required(s::StatsStep)
@@ -184,6 +186,7 @@ end
184186

185187
_sharedby(s::SharedStatsStep) = s.ids
186188
_f(s::SharedStatsStep) = _f(s.step)
189+
_byid(s::SharedStatsStep) = _byid(s.step)
187190
groupargs(s::SharedStatsStep, @nospecialize(ntargs::NamedTuple)) = groupargs(s.step, ntargs)
188191
combinedargs(s::SharedStatsStep, v::AbstractArray) = combinedargs(s.step, v)
189192
copyargs(s::SharedStatsStep) = copyargs(s.step)
@@ -476,17 +479,19 @@ function proceed(sps::Vector{<:StatsSpec};
476479
end
477480

478481
steps = pool((p for p in keys(gids))...)
479-
tasks = IdDict{Tuple, Vector{Int}}()
482+
tasks_byid = IdDict{Tuple, Vector{Int}}()
483+
tasks_byeq = Dict{Tuple, Vector{Int}}()
480484
ntask_total = 0
481485
step_count = 0
482486
paused = false
483487
@inbounds for step in steps
488+
tasks = _byid(step) ? tasks_byid : tasks_byeq
484489
ntask = 0
485490
verbose && print("Running ", step, "...")
486-
# Group arguments by ===
491+
# Group arguments by objectid or isequal
487492
for i in _sharedby(step)
488-
taskids = gids[steps.procs[i]]
489-
for j in taskids
493+
traceids = gids[steps.procs[i]]
494+
for j in traceids
490495
push!(get!(Vector{Int}, tasks, groupargs(step, traces[j])), j)
491496
end
492497
end

src/did.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,22 @@ _key(::AbstractTreatment) = :tr
1818
_key(::AbstractParallel) = :pr
1919
_key(::Any) = throw(ArgumentError("unacceptable positional arguments"))
2020

21+
function _totermset!(args::Dict{Symbol,Any}, s::Symbol)
22+
if haskey(args, s) && !(args[s] isa TermSet)
23+
ts = TermSet()
24+
foreach(t->setindex!(ts, nothing, t), args[s])
25+
args[s] = ts
26+
end
27+
end
28+
2129
"""
2230
parse_didargs!(args::Vector{Any}, kwargs::Dict{Symbol,Any})
2331
2432
Return a `Dict` that is suitable for being passed to
2533
[`valid_didargs`](@ref) for further processing.
2634
2735
Any [`TreatmentTerm`](@ref) or [`FormulaTerm`](@ref) in `args` is decomposed.
36+
Any collection of terms is converted to `TermSet`.
2837
Keys are assigned to all positional arguments based on their types.
2938
An optional `name` for [`StatsSpec`](@ref) can be included in `args` as a string.
3039
The order of positional arguments is irrelevant.
@@ -39,8 +48,8 @@ function parse_didargs!(args::Vector{Any}, kwargs::Dict{Symbol,Any})
3948
kwargs[_key(treat.pr)] = treat.pr
4049
kwargs[:yterm] = arg.lhs
4150
kwargs[:treatname] = treat.sym
42-
intacts==() || (kwargs[:treatintterms] = intacts)
43-
xs==() || (kwargs[:xterms] = xs)
51+
kwargs[:treatintterms] = intacts
52+
kwargs[:xterms] = xs
4453
elseif arg isa TreatmentTerm
4554
kwargs[_key(arg.tr)] = arg.tr
4655
kwargs[_key(arg.pr)] = arg.pr
@@ -49,6 +58,7 @@ function parse_didargs!(args::Vector{Any}, kwargs::Dict{Symbol,Any})
4958
kwargs[_key(arg)] = arg
5059
end
5160
end
61+
foreach(n->_totermset!(kwargs, n), (:treatintterms, :xterms))
5262
return kwargs
5363
end
5464

src/procedures.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Check `data` is a `Table` and find valid rows for options `subset` and `weightname`.
55
See also [`CheckData`](@ref).
66
"""
7-
function checkdata(data, subset::Union{AbstractVector, Nothing},
7+
function checkdata(data, subset::Union{BitVector, Nothing},
88
weightname::Union{Symbol, Nothing})
99

1010
istable(data) ||
@@ -13,7 +13,7 @@ function checkdata(data, subset::Union{AbstractVector, Nothing},
1313
if subset !== nothing
1414
length(subset) != size(data, 1) &&
1515
throw(DimensionMismatch("`data` of $(size(data, 1)) rows
16-
cannot be matched with `subset` vector of $(length(subset)) elements"))
16+
cannot be matched with subset vector of $(length(subset)) elements"))
1717
esample = .!ismissing.(subset) .& subset
1818
else
1919
esample = trues(size(data, 1))
@@ -34,11 +34,34 @@ end
3434
Call [`DiffinDiffsBase.checkdata`](@ref)
3535
for some preliminary checks of the input data.
3636
"""
37-
const CheckData = StatsStep{:CheckData, typeof(checkdata)}
37+
const CheckData = StatsStep{:CheckData, typeof(checkdata), true}
3838

3939
required(::CheckData) = (:data,)
4040
default(::CheckData) = (subset=nothing, weightname=nothing)
4141

42+
"""
43+
groupterms(args...)
44+
45+
Return the arguments for allowing later comparisons based on object-id.
46+
See also [`GroupTerms`](@ref).
47+
"""
48+
groupterms(treatintterms::TermSet, xterms::TermSet) =
49+
(treatintterms = treatintterms, xterms = xterms)
50+
51+
"""
52+
GroupTerms <: StatsStep
53+
54+
Call [`DiffinDiffsBase.groupterms`](@ref)
55+
to obtain one of the instances of `treatintterms` and `xterms`
56+
that have been grouped by `isequal`
57+
for allowing later comparisons based on object-id.
58+
59+
This step is only useful when working with [`@specset`](@ref) and [`proceed`](@ref).
60+
"""
61+
const GroupTerms = StatsStep{:GroupTerms, typeof(groupterms), false}
62+
63+
required(::GroupTerms) = (:treatintterms, :xterms)
64+
4265
function _overlaptime(tr::DynamicTreatment, tr_rows::BitVector, data)
4366
control_time = Set(view(getcolumn(data, tr.time), .!tr_rows))
4467
treated_time = Set(view(getcolumn(data, tr.time), tr_rows))
@@ -76,7 +99,7 @@ See also [`CheckVars`](@ref).
7699
"""
77100
function checkvars!(data, tr::AbstractTreatment, pr::AbstractParallel,
78101
yterm::AbstractTerm, treatname::Symbol, esample::BitVector,
79-
treatintterms::Terms, xterms::Terms)
102+
treatintterms::TermSet, xterms::TermSet)
80103

81104
treatvars = union([treatname], (termvars(t) for t in (tr, pr, treatintterms))...)
82105
for v in treatvars
@@ -105,10 +128,10 @@ end
105128
106129
Call [`DiffinDiffsBase.checkvars!`](@ref) to exclude invalid rows for relevant variables.
107130
"""
108-
const CheckVars = StatsStep{:CheckVars, typeof(checkvars!)}
131+
const CheckVars = StatsStep{:CheckVars, typeof(checkvars!), true}
109132

110133
required(::CheckVars) = (:data, :tr, :pr, :yterm, :treatname, :esample)
111-
default(::CheckVars) = (treatintterms=(), xterms=())
134+
default(::CheckVars) = (treatintterms=TermSet(), xterms=TermSet())
112135
copyargs(::CheckVars) = (6,)
113136

114137
"""
@@ -133,7 +156,7 @@ end
133156
134157
Call [`DiffinDiffsBase.makeweights`](@ref) to create a generic `Weights` vector.
135158
"""
136-
const MakeWeights = StatsStep{:MakeWeights, typeof(makeweights)}
159+
const MakeWeights = StatsStep{:MakeWeights, typeof(makeweights), true}
137160

138161
required(::MakeWeights) = (:data, :esample)
139162
default(::MakeWeights) = (weightname=nothing,)

src/terms.jl

Lines changed: 68 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
const Terms{N} = NTuple{N, AbstractTerm} where N
2+
const TermSet = IdDict{AbstractTerm, Nothing}
23

34
"""
45
eachterm(t)
@@ -68,72 +69,95 @@ Extract terms related to treatment specifications from `formula`.
6869
6970
# Returns
7071
- `TreatmentTerm`: the unique `TreatmentTerm` contained in the `formula`.
71-
- `Terms`: a tuple of any term that is interacted with the `TreatmentTerm`.
72-
- `Terms`: a tuple of remaining terms in `formula.rhs`.
72+
- `TermSet`: a set of terms that are interacted with the `TreatmentTerm`.
73+
- `TermSet`: a set of remaining terms in `formula.rhs`.
7374
7475
Error will be raised if either existence or uniqueness of the `TreatmentTerm` is violated.
7576
"""
7677
function parse_treat(@nospecialize(formula::FormulaTerm))
77-
# Use Array for detecting duplicate terms
78-
treats = Pair{TreatmentTerm,Tuple}[]
78+
mettreat = false
79+
treatterm = nothing
80+
ints = TermSet()
81+
xterms = TermSet()
7982
for term in eachterm(formula.rhs)
80-
if hastreat(term)
81-
if term isa TreatmentTerm
82-
push!(treats, term=>())
83-
elseif term isa FunctionTerm
84-
push!(treats, treat(term.args_parsed...)=>())
85-
elseif term isa InteractionTerm
86-
trs = []
87-
ints = []
83+
if term isa TreatmentTerm
84+
if !mettreat
85+
mettreat = true
86+
treatterm = term
87+
else
88+
throw(ArgumentError("cannot accept more than one TreatmentTerm"))
89+
end
90+
elseif term isa FunctionTerm{typeof(treat)}
91+
if !mettreat
92+
mettreat = true
93+
treatterm = treat(term.args_parsed...)
94+
else
95+
throw(ArgumentError("cannot accept more than one TreatmentTerm"))
96+
end
97+
elseif term isa InteractionTerm
98+
if hastreat(term)
99+
!mettreat ||
100+
throw(ArgumentError("cannot accept more than one TreatmentTerm"))
88101
for t in term.terms
89-
if hastreat(t)
90-
if t isa TreatmentTerm
91-
push!(trs, t)
92-
elseif t isa FunctionTerm
93-
push!(trs, treat(t.args_parsed...))
102+
if t isa TreatmentTerm
103+
if !mettreat
104+
mettreat = true
105+
treatterm = t
106+
else
107+
throw(ArgumentError("cannot accept more than one TreatmentTerm"))
108+
end
109+
elseif t isa FunctionTerm{typeof(treat)}
110+
if !mettreat
111+
mettreat = true
112+
treatterm = treat(t.args_parsed...)
113+
else
114+
throw(ArgumentError("cannot accept more than one TreatmentTerm"))
94115
end
95116
else
96-
push!(ints, t)
117+
ints[t] = nothing
97118
end
98119
end
99-
if length(trs)!=1
100-
throw(ArgumentError("invlid term $term in formula.
101-
An interaction term may contain at most one instance of `TreatmentTerm`."))
102-
else
103-
push!(treats, trs[1]=>Tuple(ints))
104-
end
120+
else
121+
xterms[term] = nothing
105122
end
123+
else
124+
xterms[term] = nothing
106125
end
107126
end
108-
length(treats)>1 &&
109-
throw(ArgumentError("cannot accept more than one `TreatmentTerm`."))
110-
isempty(treats) &&
111-
throw(ArgumentError("no `TreatmentTerm` is found."))
112-
xterms = Tuple(term for term in eachterm(formula.rhs) if !hastreat(term))
113-
return treats[1][1], treats[1][2], xterms
127+
mettreat || throw(ArgumentError("no TreatmentTerm is found"))
128+
return treatterm::TreatmentTerm, ints, xterms
114129
end
115130

116-
# A tentative solution to changes made in StatsModels v0.6.21
117-
hasintercept(::Tuple{}) = false
118-
omitsintercept(::Tuple{}) = false
119-
120131
isintercept(t::AbstractTerm) = t in (InterceptTerm{true}(), ConstantTerm(1))
121132
isomitsintercept(t::AbstractTerm) =
122133
t in (InterceptTerm{false}(), ConstantTerm(0), ConstantTerm(-1))
123134

124135
"""
125-
parse_intercept(ts::Terms)
136+
parse_intercept!(ts::TermSet)
126137
127-
Convert any `ConstantTerm` to `InterceptTerm` and add them to the end of the tuple.
128-
This is useful for obtaining a unique way of specifying the intercept
138+
Convert any `ConstantTerm` to `InterceptTerm`
139+
and return Boolean values indicating whether terms explicitly requiring
140+
including/excluding the intercept exist.
141+
142+
This function is useful for obtaining a unique way of specifying the intercept
129143
before going through the `schema`--`apply_schema` pipeline defined in `StatsModels`.
130144
"""
131-
function parse_intercept(@nospecialize(ts::Terms))
132-
out = AbstractTerm[t for t in ts if !(isintercept(t) || isomitsintercept(t))]
133-
omitsintercept(ts) && push!(out, InterceptTerm{false}())
134-
# This order is assumed by InteractionWeightedDIDs.Fstat
135-
hasintercept(ts) && push!(out, InterceptTerm{true}())
136-
return (out...,)
145+
function parse_intercept!(ts::TermSet)
146+
hasintercept = false
147+
hasomitsintercept = false
148+
for t in keys(ts)
149+
if isintercept(t)
150+
delete!(ts, t)
151+
hasintercept = true
152+
end
153+
if isomitsintercept(t)
154+
delete!(ts, t)
155+
hasomitsintercept = true
156+
end
157+
end
158+
hasintercept && (ts[InterceptTerm{true}()] = nothing)
159+
hasomitsintercept && (ts[InterceptTerm{false}()] = nothing)
160+
return hasintercept, hasomitsintercept
137161
end
138162

139-
termvars(::Tuple{}) = Symbol[]
163+
termvars(ts::TermSet) = mapreduce(termvars, union, keys(ts), init=Symbol[])

test/StatsProcedures.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,29 @@ using DiffinDiffsBase: _f, _get, groupargs,
33
import DiffinDiffsBase: required, default, transformed, combinedargs, copyargs
44

55
testvoidstep(a::String) = NamedTuple()
6-
const TestVoidStep = StatsStep{:TestVoidStep, typeof(testvoidstep)}
6+
const TestVoidStep = StatsStep{:TestVoidStep, typeof(testvoidstep), true}
77
required(::TestVoidStep) = (:a,)
88

99
testregstep(a::String, b::String) = (c=a*b,)
10-
const TestRegStep = StatsStep{:TestRegStep, typeof(testregstep)}
10+
const TestRegStep = StatsStep{:TestRegStep, typeof(testregstep), true}
1111
default(::TestRegStep) = (a="a", b="b")
1212

1313
testlaststep(a::String, c::String) = (result=a*c,)
14-
const TestLastStep = StatsStep{:TestLastStep, typeof(testlaststep)}
14+
const TestLastStep = StatsStep{:TestLastStep, typeof(testlaststep), true}
1515
default(::TestLastStep) = (a="a",)
1616
transformed(::TestLastStep, ntargs::NamedTuple) = (ntargs.c,)
1717

1818
testcombinestep(a::String, bs::String...) = (c=collect(bs),)
19-
const TestCombineStep = StatsStep{:TestCombineStep, typeof(testcombinestep)}
19+
const TestCombineStep = StatsStep{:TestCombineStep, typeof(testcombinestep), true}
2020
default(::TestCombineStep) = (a="a",)
2121
combinedargs(::TestCombineStep, ntargs) = [nt.b for nt in ntargs]
2222

2323
testarraystep(a::String, c::Array) = (result=c,)
24-
const TestArrayStep = StatsStep{:TestArrayStep, typeof(testarraystep)}
24+
const TestArrayStep = StatsStep{:TestArrayStep, typeof(testarraystep), true}
2525
required(::TestArrayStep) = (:a, :c)
2626
copyargs(::TestArrayStep) = (2,)
2727

28-
const TestUnnamedStep = StatsStep{:TestUnnamedStep, typeof(testregstep)}
28+
const TestUnnamedStep = StatsStep{:TestUnnamedStep, typeof(testregstep), true}
2929

3030
@testset "StatsStep" begin
3131
@testset "_get" begin

0 commit comments

Comments
 (0)