Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 922b28b

Browse files
committed
Add more tests and fix some issues
1 parent 760cc2e commit 922b28b

File tree

10 files changed

+138
-41
lines changed

10 files changed

+138
-41
lines changed

src/StatsProcedures.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ An instance of `StatsStep` is callable.
1717
Call an instance of function of type `F` with arguments from `ntargs`
1818
formed by accessing the keys in `S` and `T` sequentially.
1919
20-
If a keyword argument `verbose` takes `true`
21-
or `ntargs` contains a key-value pair `verbose=true`,
22-
a message with the name of the `StatsStep` is printed to `stdout`.
20+
A message with the name of the `StatsStep` is printed to `stdout`
21+
if a keyword `verbose` takes the value `true`
22+
or `ntargs` contains a key-value pair `verbose=true`.
23+
The value from `ntargs` supersedes the keyword argument
24+
in case both are specified.
2325
2426
## Returns
2527
- `NamedTuple`: named intermidiate results.
@@ -31,8 +33,8 @@ _specnames(::StatsStep{A,F,S}) where {A,F,S} = S
3133
_tracenames(::StatsStep{A,F,S,T}) where {A,F,S,T} = T
3234

3335
function (step::StatsStep{A,F,S,T})(ntargs::NamedTuple; verbose::Bool=false) where {A,F,S,T}
34-
verbose || (haskey(ntargs, :verbose) && ntargs.verbose) &&
35-
println(" Running ", step)
36+
haskey(ntargs, :verbose) && (verbose = ntargs.verbose)
37+
verbose && printstyled("Running ", step, "\n", color=:green)
3638
args = NamedTuple{(S...,T...)}(ntargs)
3739
ret = F.instance(args...)
3840
if ret isa NamedTuple
@@ -323,7 +325,7 @@ while ignoring the orders.
323325
function (sp::StatsSpec{A,T})(;
324326
verbose::Bool=false, keep=nothing, keepall::Bool=false) where {A,T}
325327
args = verbose ? merge(sp.args, (verbose=true,)) : sp.args
326-
ntall = foldl(|>, T(), init=sp.args)
328+
ntall = foldl(|>, T(), init=args)
327329
if keepall
328330
return ntall
329331
elseif !isempty(ntall)
@@ -336,7 +338,7 @@ function (sp::StatsSpec{A,T})(;
336338
elseif eltype(keep) != Symbol
337339
throw(ArgumentError("expect Symbol or collections of Symbols for the value of option `keep`"))
338340
end
339-
return (result=res, (kv for kv in pairs(ntall) if kv[1] in keep)...)
341+
return ((kv for kv in pairs(ntall) if kv[1] in keep)..., result=res)
340342
end
341343
else
342344
return nothing
@@ -361,7 +363,7 @@ function run_specset(sps::AbstractVector{<:StatsSpec};
361363
gids = groupfind(r->procedure(r.spec), tb)
362364
steps = pool((p() for p in keys(gids))...)
363365
for step in steps
364-
verbose && println(" Running ", step)
366+
verbose && printstyled("Running ", step, "\n", color=:green)
365367
args = _specnames(step)
366368
tras = _tracenames(step)
367369
byf = r->merge(NamedTuple{args}(r.spec.args), NamedTuple{tras}(r.trace))

src/parallels.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,13 @@ A wrapper method of `notyettreated` for working with `@formula`.
243243
"""
244244
@unpack notyettreated
245245

246-
termvars(c::ParallelCondition) = Symbol[]
247-
termvars(s::ParallelStrength) = Symbol[]
248-
termvars(pr::AbstractParallel) = union(termvars(pr.c), termvars(pr.s))
246+
termvars(c::ParallelCondition) =
247+
error("StatsModels.termvars is not defined for $(typeof(c))")
248+
termvars(::Unconditional) = Symbol[]
249+
termvars(s::ParallelStrength) =
250+
error("StatsModels.termvars is not defined for $(typeof(s))")
251+
termvars(::Exact) = Symbol[]
252+
termvars(pr::AbstractParallel) =
253+
error("StatsModels.termvars is not defined for $(typeof(pr))")
254+
termvars(pr::NeverTreatedParallel) = union(termvars(pr.c), termvars(pr.s))
255+
termvars(pr::NotYetTreatedParallel) = union(termvars(pr.c), termvars(pr.s))

src/procedures.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,45 @@
1+
"""
2+
checkdata(args...)
13
2-
function check_data(data, tr::AbstractTreatment, pr::AbstractParallel,
4+
Check `data` is a `Table` and find rows with nonmissing values for variables.
5+
See also [`CheckData`](@ref).
6+
7+
# Returns
8+
- `vars::Vector{Symbol}`: column names of relevant variables.
9+
- `esample::BitArray`: Boolean indices for rows with nonmissing values for the variables.
10+
"""
11+
function checkdata(data, tr::AbstractTreatment, pr::AbstractParallel,
312
yterm::AbstractTerm, treatname::Symbol, xterms::TupleTerm,
413
weights::Union{Symbol, Nothing}, subset::Union{AbstractVector, Nothing})
514

615
istable(data) ||
716
throw(ArgumentError("expected data in a Table, got $(typeof(data))"))
817

9-
vars = union(treatname, (termvars(t) for t in (tr, pr, yterm, xterms)))
18+
vars = union([treatname], (termvars(t) for t in (tr, pr, yterm, xterms))...)
1019
esample = BitArray(all(v->!ismissing(getproperty(row, v)), vars) for row in rows(data))
1120

12-
if subset != nothing
21+
if subset !== nothing
1322
length(subset) != size(data, 1) &&
14-
throw("df has $(size(df, 1)) rows but the subset vector has $(length(subset)) elements")
15-
esample .&= .!ismissing.(x) .& x
23+
throw(DimensionMismatch("`data` of $(size(data, 1)) rows cannot be matched with `subset` vector of $(length(subset)) elements"))
24+
esample .&= .!ismissing.(subset) .& subset
1625
end
1726

18-
if weights != nothing
27+
if weights !== nothing
1928
colweights = getcolumn(columns(data), weights)
2029
esample .&= .!ismissing.(colweights) .& (colweights .> 0)
2130
end
2231

23-
sum(esample) == 0 && throw(ArgumentError("no nonmissing data"))
32+
sum(esample) == 0 && error("no nonmissing data")
2433

25-
return (vars=vars, esample=esample,)
34+
return (vars=vars, esample=esample)
2635
end
2736

28-
const CheckData = StatsStep{:CheckData, typeof(check_data), (:data, :tr, :pr, :yterm, :treatname, :xterms, :weights, :subset), ()}
37+
"""
38+
CheckData
2939
40+
A [`StatsStep`](@ref) that calls [`DiffinDiffsBase.checkdata`](@ref)
41+
for some preliminary checks of the input data.
42+
"""
43+
const CheckData = StatsStep{:CheckData, typeof(checkdata), (:data, :tr, :pr, :yterm, :treatname, :xterms, :weights, :subset), ()}
3044

3145

src/treatments.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ abstract type AbstractTreatment end
4141
@fieldequal AbstractTreatment
4242

4343
"""
44-
DynamicTreatment{E<:EleOrVec{<:Integer},S<:TreatmentSharpness} <: AbstractTreatment
44+
DynamicTreatment{E<:EleOrVec{<:Integer}, S<:TreatmentSharpness} <: AbstractTreatment
4545
4646
Specify an absorbing binary treatment with effects allowed to evolve over time.
4747
See also [`dynamic`](@ref).
@@ -51,18 +51,14 @@ See also [`dynamic`](@ref).
5151
- `exc::E`: excluded relative time (either an integer or vector of integers).
5252
- `s::S`: a [`TreatmentSharpness`](@ref).
5353
"""
54-
struct DynamicTreatment{E<:EleOrVec{<:Integer},S<:TreatmentSharpness} <: AbstractTreatment
54+
struct DynamicTreatment{E<:EleOrVec{<:Integer}, S<:TreatmentSharpness} <: AbstractTreatment
5555
time::Symbol
5656
exc::E
5757
s::S
58-
function DynamicTreatment(time::Symbol, exc::E, s::S) where
59-
{E<:EleOrVec{<:Integer},S<:TreatmentSharpness}
60-
if length(exc) > 1
61-
exc = sort!(exc)
62-
elseif E <: Vector
63-
exc = exc[1]
64-
end
65-
return new{typeof(exc),S}(time, exc, s)
58+
function DynamicTreatment(time::Symbol, exc, s::TreatmentSharpness)
59+
exc = unique!(sort!([exc...]))
60+
length(exc)==1 && (exc = exc[1])
61+
return new{typeof(exc),typeof(s)}(time, exc, s)
6662
end
6763
end
6864

@@ -105,7 +101,7 @@ Sharp dynamic treatment:
105101
```
106102
"""
107103
dynamic(time::Symbol, exc, s::TreatmentSharpness=sharp()) =
108-
DynamicTreatment(time, [exc...], s)
104+
DynamicTreatment(time, exc, s)
109105

110106
"""
111107
dynamic(ts::AbstractTerm...)
@@ -114,6 +110,9 @@ A wrapper method of `dynamic` for working with `@formula`.
114110
"""
115111
@unpack dynamic
116112

117-
termvars(s::TreatmentSharpness) = Symbol[]
118-
termvars(tr::AbstractTreatment) = termvars(tr.s)
119-
termvars(tr::DynamicTreatment) = pushfirst!(termvars(s::TreatmentSharpness), tr.time)
113+
termvars(s::TreatmentSharpness) =
114+
error("StatsModels.termvars is not defined for $(typeof(s))")
115+
termvars(::SharpDesign) = Symbol[]
116+
termvars(tr::AbstractTreatment) =
117+
error("StatsModels.termvars is not defined for $(typeof(tr))")
118+
termvars(tr::DynamicTreatment) = [tr.time, termvars(tr.s)...]

test/StatsProcedures.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using DiffinDiffsBase: _f, _specnames, _tracenames,
2-
_sharedby
2+
_sharedby, _show_args
33

4-
testvoidstep(a::String, b::String) = println(a)
4+
testvoidstep(a::String, b::String) = nothing
55
const TestVoidStep = StatsStep{:TestVoidStep, typeof(testvoidstep), (:a, :b), ()}
66

77
testregstep(a::String, b::String) = (b=a, c=a*b,)
@@ -192,3 +192,32 @@ end
192192
NullProcedure"""
193193
end
194194

195+
@testset "StatsSpec" begin
196+
s1 = StatsSpec("name", RP, (a="a",b="b"))
197+
s2 = StatsSpec("", RP, (a="a",b="b"))
198+
s3 = StatsSpec("", UP, (a="a",b="b"))
199+
s4 = StatsSpec("name", RP, (b="b", a="a"))
200+
s5 = StatsSpec("name", RP, (b="b", a="a", d="d"))
201+
@test s1 == s2
202+
@test s2 != s3
203+
@test s2 != s4
204+
@test s2 s4
205+
206+
@test s1() == "aab"
207+
@test s3() == "ab"
208+
209+
@test s1(keep=:a) == (a="a", result="aab")
210+
@test s1(keep=(:a,:c)) == (a="a", c="ab", result="aab")
211+
@test_throws ArgumentError s1(keep=1)
212+
@test s1(keepall=true) == (a="a", b="a", c="ab", result="aab")
213+
214+
s6 = StatsSpec("", NP, NamedTuple())
215+
@test s6() === nothing
216+
217+
@test sprint(show, s1) == "name"
218+
@test sprint(show, s2) == "unnamed"
219+
@test sprint(show, MIME("text/plain"), s1) == "name (StatsSpec for RegProcedure)"
220+
@test sprint(show, MIME("text/plain"), s2) == "unnamed (StatsSpec for RegProcedure)"
221+
222+
@test _show_args(stdout, s1) === nothing
223+
end

test/parallels.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,14 @@ end
170170
Treated since: [0, 1]"""
171171
end
172172
end
173+
174+
@testset "termvars" begin
175+
@test termvars(Unconditional()) == Symbol[]
176+
@test termvars(Exact()) == Symbol[]
177+
@test termvars(nevertreated(-1)) == Symbol[]
178+
@test termvars(notyettreated(5)) == Symbol[]
179+
180+
@test_throws ErrorException termvars(TestParaCondition())
181+
@test_throws ErrorException termvars(TestParaStrength())
182+
@test_throws ErrorException termvars(PR)
183+
end

test/procedures.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,37 @@
1-
using DiffinDiffsBase: _f, _specnames, _tracenames,
2-
check_data
1+
using DiffinDiffsBase: _f, _specnames, _tracenames, checkdata
32

4-
@testset "check_data" begin
3+
@testset "CheckData" begin
54
@testset "StatsStep" begin
65
@test sprint(show, CheckData()) == "CheckData"
76
@test sprint(show, MIME("text/plain"), CheckData()) == """
8-
CheckData (StatsStep that calls DiffinDiffsBase.check_data):
7+
CheckData (StatsStep that calls DiffinDiffsBase.checkdata):
98
arguments from StatsSpec: (:data, :tr, :pr, :yterm, :treatname, :xterms, :weights, :subset)
109
arguments from trace: ()"""
1110

12-
@test _f(CheckData()) == check_data
11+
@test _f(CheckData()) == checkdata
1312
@test _specnames(CheckData()) == (:data, :tr, :pr, :yterm, :treatname, :xterms, :weights, :subset)
1413
@test _tracenames(CheckData()) == ()
1514
end
15+
16+
@testset "checkdata" begin
17+
hrs = exampledata("hrs")
18+
nt = (data=hrs, tr=dynamic(:wave, -1), pr=nevertreated(11),
19+
yterm=term(:oop_spend), treatname=:wave_hosp, xterms=(),
20+
weights=nothing, subset=nothing)
21+
@test checkdata(nt...) ==
22+
(vars=[:wave_hosp, :wave, :oop_spend], esample=trues(size(hrs,1)))
23+
24+
nt = merge(nt, (weights=:rwthh, subset=hrs.male))
25+
@test checkdata(nt...) ==
26+
(vars=[:wave_hosp, :wave, :oop_spend], esample=BitArray(hrs.male))
27+
28+
nt = merge(nt, (data=rand(10,10),))
29+
@test_throws ArgumentError checkdata(nt...)
30+
31+
nt = merge(nt, (data=hrs, subset=BitArray(hrs.male[1:100])))
32+
@test_throws DimensionMismatch checkdata(nt...)
33+
34+
nt = merge(nt, (subset=falses(size(hrs,1)),))
35+
@test_throws ErrorException checkdata(nt...)
36+
end
1637
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using Test
22
using DiffinDiffsBase
33

4+
using StatsModels: termvars
5+
46
include("testutils.jl")
57

68
const tests = [

test/testutils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ import Base: show
44

55
sprintcompact(x) = sprint(show, x; context=:compact=>true)
66

7+
struct TestSharpness <: TreatmentSharpness end
8+
struct TestParaCondition <: ParallelCondition end
9+
struct TestParaStrength <: ParallelStrength end
10+
711
struct TestTreatment <: AbstractTreatment
812
time::Symbol
913
ref::Int

test/treatments.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,11 @@ end
5656
excluded relative time: [-2, -1]"""
5757
end
5858
end
59+
60+
@testset "termvars" begin
61+
@test termvars(sharp()) == Symbol[]
62+
@test termvars(dynamic(:month, -1)) == [:month]
63+
64+
@test_throws ErrorException termvars(TestSharpness())
65+
@test_throws ErrorException termvars(TR)
66+
end

0 commit comments

Comments
 (0)