Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit cf3049c

Browse files
committed
Replace Dict with NamedTuple for StatsSpec
1 parent cc83160 commit cf3049c

File tree

6 files changed

+128
-58
lines changed

6 files changed

+128
-58
lines changed

src/DiffinDiffsBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export @fieldequal,
1414
unpack,
1515
kwarg,
1616
@unpack,
17+
,
1718
exampledata,
1819
sprintcompact,
1920

src/did.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ function did(d::Type{<:DiffinDiffsEstimator}, @nospecialize(formula::FormulaTerm
4343
yterm=formula.lhs, treatname=treat.sym, ints..., xterms..., kwargs...)
4444
end
4545

46-
argpair(arg::Type{<:DiffinDiffsEstimator}) = :d => arg
47-
argpair(arg::AbstractString) = :name => String(arg)
4846
argpair(arg::AbstractTreatment) = :tr => arg
4947
argpair(arg::AbstractParallel) = :pr => arg
5048
argpair(::Any) = throw(ArgumentError("unacceptable positional arguments"))
@@ -61,14 +59,20 @@ that can be accepted by [`did`](@ref).
6159
# Returns
6260
- `Type{<:DiffinDiffsEstimator}`: either the type found in positional arguments or `DefaultDID`.
6361
- `String`: either a string found in positional arguments or `""` if no instance of any subtype of `AbstractString` is found.
64-
- `Dict`: up to two key-value pairs for instances of [`AbstractTreatment`](@ref) and [`AbstractParallel`](@ref) with keys being `:tr` and `:pr`.
65-
- `Dict`: keyword arguments with possibly additional pairs after parsing positional arguments.
62+
- `NamedTuple`: contain at most one instance of [`AbstractTreatment`](@ref) and [`AbstractParallel`](@ref) with associated keys being `:tr` and `:pr` respectively.
63+
- `NamedTuple`: keyword arguments with possibly additional elements after parsing positional arguments.
6664
"""
6765
function parse_didargs(args...; kwargs...)
66+
sptypes = Type{<:DiffinDiffsEstimator}[]
67+
names = String[]
6868
pargs = Pair{Symbol,Any}[]
6969
pkwargs = Pair{Symbol,Any}[kwargs...]
7070
for arg in args
71-
if arg isa FormulaTerm
71+
if arg isa Type{<:DiffinDiffsEstimator}
72+
push!(sptypes, arg)
73+
elseif arg isa AbstractString
74+
push!(names, String(arg))
75+
elseif arg isa FormulaTerm
7276
treat, intacts, xs = parse_treat(arg)
7377
push!(pargs, argpair(treat.tr), argpair(treat.pr))
7478
push!(pkwargs, :yterm => arg.lhs, :treatname => treat.sym)
@@ -81,12 +85,22 @@ function parse_didargs(args...; kwargs...)
8185
push!(pargs, argpair(arg))
8286
end
8387
end
84-
args = Dict{Symbol,Any}(pargs...)
85-
kwargs = Dict{Symbol,Any}(pkwargs...)
86-
length(args) == length(pargs) && length(kwargs) == length(pkwargs) ||
88+
if length(sptypes) > 1 || length(names) > 1
8789
throw(ArgumentError("redundant arguments encountered"))
88-
sptype = pop!(args, :d, DefaultDID)
89-
name = pop!(args, :name, "")
90+
else
91+
keyargs = first.(pargs)
92+
if length(keyargs) != length(unique(keyargs))
93+
throw(ArgumentError("redundant arguments encountered"))
94+
else
95+
keykwargs = first.(pkwargs)
96+
length(keykwargs) != length(unique(keykwargs)) &&
97+
throw(ArgumentError("redundant arguments encountered"))
98+
end
99+
end
100+
sptype = isempty(sptypes) ? DefaultDID : sptypes[1]
101+
name = isempty(names) ? "" : names[1]
102+
args = (; pargs...)
103+
kwargs = (; pkwargs...)
90104
return sptype, name, args, kwargs
91105
end
92106

src/procedures.jl

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ iterate(p::AbstractStatsProcedure{T}, state=1) where T =
2020
state > length(p) ? nothing : (p[state], state+1)
2121

2222
"""
23-
StatsSpec{T<:AbstractStatsProcedure, IsComplete}
23+
StatsSpec{T<:AbstractStatsProcedure, IsValidated}
2424
2525
Record the specification for a statistical procedure of type `T`
26-
that may or may not be verified to be complete as indicated by `IsComplete`.
26+
that may or may not be verified to be valid as indicated by `IsValidated`.
2727
2828
The specification is recorded based on the arguments
2929
for a function that will conduct the procedure.
@@ -32,21 +32,40 @@ can be constructed solely based on the type of each argument.
3232
3333
# Fields
3434
- `name::String`: an optional name for the specification.
35-
- `args::Dict{Symbol}`: positional arguments indexed based on their types.
36-
- `kwargs::Dict{Symbol}`: keyword arguments.
35+
- `args::NamedTuple`: positional arguments indexed based on their types.
36+
- `kwargs::NamedTuple`: keyword arguments.
3737
"""
38-
struct StatsSpec{T<:AbstractStatsProcedure, IsComplete}
38+
struct StatsSpec{T<:AbstractStatsProcedure, IsValidated}
3939
name::String
40-
args::Dict{Symbol}
41-
kwargs::Dict{Symbol}
40+
args::NamedTuple
41+
kwargs::NamedTuple
4242
end
4343

4444
StatsSpec(T::Type{<:AbstractStatsProcedure}, name::String,
45-
args::Dict{Symbol}, kwargs::Dict{Symbol}, IsComplete::Bool=false) =
46-
StatsSpec{T,IsComplete}(name, args, kwargs)
45+
args::NamedTuple, kwargs::NamedTuple, IsValidated::Bool=false) =
46+
StatsSpec{T,IsValidated}(name, args, kwargs)
4747

48-
==(a::StatsSpec{T}, b::StatsSpec{T}) where {T<:AbstractStatsProcedure} =
49-
a.args == b.args && a.kwargs == b.kwargs
48+
"""
49+
==(x::StatsSpec{T}, y::StatsSpec{T}) where T
50+
51+
Test whether two instances of [`StatsSpec`](@ref)
52+
with the same parameter `T` also have the same fields `args` and `kwargs`.
53+
54+
See also [`≊`](@ref).
55+
"""
56+
==(x::StatsSpec{T}, y::StatsSpec{T}) where T =
57+
x.args == y.args && x.kwargs == y.kwargs
58+
59+
"""
60+
≊(x::StatsSpec{T}, y::StatsSpec{T}) where T
61+
62+
Test whether two instances of [`StatsSpec`](@ref)
63+
with the same parameter `T` also have the fields `args` and `kwargs`
64+
containing the same sets of key-value pairs
65+
while ignoring the orders.
66+
"""
67+
(x::StatsSpec{T}, y::StatsSpec{T}) where T =
68+
x.args y.args && x.kwargs y.kwargs
5069

5170
isnamed(sp::StatsSpec) = sp.name != ""
5271

src/terms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ Return a `Tuple` of three objects extracted from the right-hand-side of `formula
6262
Error will be raised if either existence or uniqueness of the `TreatmentTerm` is violated.
6363
"""
6464
function parse_treat(@nospecialize(formula::FormulaTerm))
65-
# Use Array instead of Dict for detecting duplicate terms
65+
# Use Array for detecting duplicate terms
6666
treats = Pair{TreatmentTerm,Tuple}[]
6767
for term in eachterm(formula.rhs)
6868
if hastreat(term)

src/utils.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ such that `==` returns true if two instances have the same field values.
66
"""
77
macro fieldequal(Supertype)
88
return esc(quote
9-
function ==(a::T, b::T) where T <: $Supertype
9+
function ==(x::T, y::T) where T <: $Supertype
1010
f = fieldnames(T)
11-
getfield.(Ref(a),f) == getfield.(Ref(b),f)
11+
getfield.(Ref(x),f) == getfield.(Ref(y),f)
1212
end
1313
end)
1414
end
@@ -110,6 +110,18 @@ function args_kwargs(exprs)
110110
return args, kwargs
111111
end
112112

113+
"""
114+
≊(x::NamedTuple, y::NamedTuple)
115+
116+
Test whether two instances of `NamedTuple` contain
117+
the same set of key-value pairs while ignoring the order.
118+
119+
See https://discourse.julialang.org/t/check-equality-of-two-namedtuples-with-order-of-the-fields-ignored
120+
"""
121+
(x::NamedTuple{N1,T1}, y::NamedTuple{N2,T2}) where {N1,T1,N2,T2} =
122+
length(N1) === length(union(N1,N2)) &&
123+
all(k->getfield(x,k)==getfield(y,k), keys(x))
124+
113125
"""
114126
exampledata()
115127

test/did.jl

Lines changed: 58 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ did(::Type{TestDID}, ::AbstractTreatment, ::AbstractParallel;
1313
@test eltype(d) == Function
1414
@test_throws BoundsError d[1]
1515
@test collect(d) == Function[]
16+
@test iterate(d) === nothing
1617

1718
d = TestDID()
1819
@test length(d) == 2
@@ -23,6 +24,9 @@ did(::Type{TestDID}, ::AbstractTreatment, ::AbstractParallel;
2324
@test_throws BoundsError d[3]
2425
@test_throws MethodError d[:a]
2526
@test collect(d) == Function[print, println]
27+
@test iterate(d) == (print, 2)
28+
@test iterate(d, 2) == (println, 3)
29+
@test iterate(d, 3) === nothing
2630
end
2731

2832
@testset "did wrapper" begin
@@ -68,26 +72,26 @@ end
6872
end
6973

7074
@testset "parse_didargs" begin
71-
@test parse_didargs() == (DefaultDID, "", Dict{Symbol,Any}(), Dict{Symbol,Any}())
72-
@test parse_didargs("test") == (DefaultDID, "test", Dict{Symbol,Any}(), Dict{Symbol,Any}())
75+
@test parse_didargs() == (DefaultDID, "", NamedTuple(), NamedTuple())
76+
@test parse_didargs("test") == (DefaultDID, "test", NamedTuple(), NamedTuple())
7377

7478
sptype, name, args, kwargs = parse_didargs(TestDID, TR, PR, a=1, b=2)
7579
@test sptype == TestDID
7680
@test name == ""
77-
@test args == Dict(:tr=>TR, :pr=>PR)
78-
@test kwargs == Dict(:a=>1, :b=>2)
81+
@test args == (tr=TR, pr=PR)
82+
@test kwargs == (a=1, b=2)
7983

8084
sptype, name, args, kwargs = parse_didargs("test", testterm, TestDID)
8185
@test sptype == TestDID
8286
@test name == "test"
83-
@test args == Dict(:tr=>TR, :pr=>PR)
84-
@test kwargs == Dict(:treatname=>:g)
87+
@test args == (tr=TR, pr=PR)
88+
@test kwargs == (treatname=:g,)
8589

8690
sptype, name, args0, kwargs0 = parse_didargs(TestDID, term(:y) ~ testterm, "test")
8791
@test sptype == TestDID
8892
@test name == "test"
89-
@test args0 == Dict(:tr=>TR, :pr=>PR)
90-
@test kwargs0 == Dict(:yterm=>term(:y), :treatname=>:g)
93+
@test args0 == (tr=TR, pr=PR)
94+
@test kwargs0 == (yterm=term(:y), treatname=:g)
9195

9296
sptype, name, args1, kwargs1 = parse_didargs(TestDID, @formula(y ~ treat(g, ttreat(t, 0), tpara(0))))
9397
@test sptype == TestDID
@@ -96,9 +100,8 @@ end
96100
@test kwargs1 == kwargs0
97101

98102
sptype, name, args0, kwargs0 = parse_didargs(TestDID, term(:y) ~ testterm & term(:z) + term(:x))
99-
@test args0 == Dict(:tr=>TR, :pr=>PR)
100-
@test kwargs0 == Dict(:yterm=>term(:y), :treatname=>:g,
101-
:treatintterms=>(term(:z),), :xterms=>(term(:x),))
103+
@test args0 == (tr=TR, pr=PR)
104+
@test kwargs0 == (yterm=term(:y), treatname=:g, treatintterms=(term(:z),), xterms=(term(:x),))
102105

103106
sptype, name, args1, kwargs1 = parse_didargs(TestDID,
104107
@formula(y ~ treat(g, ttreat(t, 0), tpara(0)) & z + x))
@@ -111,30 +114,52 @@ end
111114
end
112115

113116
@testset "StatsSpec" begin
117+
@testset "== ≊" begin
118+
sp1 = StatsSpec(DefaultDID, "", NamedTuple(), NamedTuple())
119+
sp2 = StatsSpec(DefaultDID, "name", NamedTuple(), NamedTuple())
120+
@test sp1 == sp2
121+
@test sp1 sp2
122+
123+
sp1 = StatsSpec(DefaultDID, "", (tr=TR, pr=PR), (a=1, b=2))
124+
sp2 = StatsSpec(DefaultDID, "", (pr=PR, tr=TR), (b=2, a=1))
125+
@test sp1 != sp2
126+
@test sp1 sp2
127+
128+
sp2 = StatsSpec(DefaultDID, "", (tr=TR, pr=PR), (a=1.0, b=2.0))
129+
@test sp1 == sp2
130+
@test sp1 sp2
131+
132+
sp2 = StatsSpec(DefaultDID, "", (tr=TR, pr=PR), (a=1, b=2, c=3))
133+
@test !(sp1 sp2)
134+
135+
sp2 = StatsSpec(DefaultDID, "", (tr=TR, pr=PR), (a=1, b=1))
136+
@test !(sp1 sp2)
137+
end
138+
114139
@testset "show" begin
115-
sp = StatsSpec(DefaultDID, "", Dict{Symbol,Any}(), Dict{Symbol,Any}())
140+
sp = StatsSpec(DefaultDID, "", NamedTuple(), NamedTuple())
116141
@test sprint(show, sp) == "StatsSpec{DefaultDID}"
117142
@test sprintcompact(sp) == "StatsSpec{DefaultDID}"
118143

119-
sp = StatsSpec(DefaultDID, "name", Dict{Symbol,Any}(), Dict{Symbol,Any}())
144+
sp = StatsSpec(DefaultDID, "name", NamedTuple(), NamedTuple())
120145
@test sprint(show, sp) == "StatsSpec{DefaultDID}: name"
121146
@test sprintcompact(sp) == "StatsSpec{DefaultDID}: name"
122147

123-
sp = StatsSpec(TestDID, "", Dict(:tr=>dynamic(:time,-1), :pr=>nevertreated(-1)),
124-
Dict{Symbol,Any}())
148+
sp = StatsSpec(TestDID, "", (tr=dynamic(:time,-1), pr=nevertreated(-1)),
149+
NamedTuple())
125150
@test sprint(show, sp) == """
126151
StatsSpec{TestDID}:
127152
Dynamic{S}(-1)
128153
NeverTreated{U,P}([-1])"""
129154
@test sprintcompact(sp) == "StatsSpec{TestDID}"
130155

131-
sp = StatsSpec(TestDID, "", Dict(:tr=>dynamic(:time,-1)), Dict{Symbol,Any}())
156+
sp = StatsSpec(TestDID, "", (tr=dynamic(:time,-1),), NamedTuple())
132157
@test sprint(show, sp) == """
133158
StatsSpec{TestDID}:
134159
Dynamic{S}(-1)"""
135160
@test sprintcompact(sp) == "StatsSpec{TestDID}"
136161

137-
sp = StatsSpec(TestDID, "name", Dict(:pr=>nevertreated(-1)), Dict{Symbol,Any}())
162+
sp = StatsSpec(TestDID, "name", (pr=nevertreated(-1),), NamedTuple())
138163
@test sprint(show, sp) == """
139164
StatsSpec{TestDID}: name
140165
NeverTreated{U,P}([-1])"""
@@ -145,11 +170,11 @@ end
145170
@testset "didspec" begin
146171
testname = "name"
147172

148-
sp = StatsSpec(DefaultDID, "", Dict{Symbol,Any}(), Dict{Symbol,Any}())
173+
sp = StatsSpec(DefaultDID, "", NamedTuple(), NamedTuple())
149174
@test didspec() == sp
150175
@test sp.name == ""
151176

152-
sp = StatsSpec(DefaultDID, "name", Dict{Symbol,Any}(), Dict{Symbol,Any}())
177+
sp = StatsSpec(DefaultDID, "name", NamedTuple(), NamedTuple())
153178
@test didspec("") == sp
154179
@test sp.name == "name"
155180

@@ -165,31 +190,30 @@ end
165190
@test sp == didspec()
166191
@test sp.name == "name"
167192

168-
sp0 = StatsSpec(TestDID, "", Dict(:tr=>TR, :pr=>PR), Dict(:a=>1, :b=>2))
193+
sp0 = StatsSpec(TestDID, "", (tr=TR, pr=PR), (a=1, b=2))
169194
sp1 = didspec(TestDID, TR, PR, a=1, b=2)
170-
@test sp1 == sp0
171-
@test sp0 == @didspec TR a=1 b=2 PR TestDID
195+
@test sp1 === sp0
196+
@test sp0 === @didspec TR a=1 b=2 PR TestDID
172197

173-
sp2 = StatsSpec(TestDID, "name", Dict(:tr=>TR, :pr=>PR), Dict(:a=>1, :b=>2))
198+
sp2 = StatsSpec(TestDID, "name", (tr=TR, pr=PR), (a=1, b=2))
174199
sp3 = didspec("name", TR, PR, TestDID, b=2, a=1)
175-
@test sp3 == sp2
176-
@test sp3 == sp1
177-
@test sp2 == @didspec TR PR TestDID "name" b=2 a=1
200+
@test sp2 == sp1
201+
@test sp3 sp2
202+
@test sp3 === @didspec TR PR TestDID "name" b=2 a=1
178203

179-
sp4 = StatsSpec(TestDID, testname, Dict(:tr=>TR, :pr=>PR), Dict(:treatname=>:g))
204+
sp4 = StatsSpec(TestDID, testname, (tr=TR, pr=PR), (treatname=:g,))
180205
sp5 = didspec(TestDID, testterm)
181206
@test sp5 == sp4
182-
@test sp4 == @didspec TestDID testterm testname
207+
@test sp4 === @didspec TestDID testterm testname
183208

184-
sp6 = StatsSpec(TestDID, "name", Dict(:tr=>TR, :pr=>PR),
185-
Dict(:yterm=>term(:y), :treatname=>:g,
186-
:treatintterms=>(term(:z),), :xterms=>(term(:x),)))
209+
sp6 = StatsSpec(TestDID, "", (tr=TR, pr=PR),
210+
(yterm=term(:y), treatname=:g, treatintterms=(term(:z),), xterms=(term(:x),)))
187211
sp7 = didspec(TestDID, term(:y) ~ testterm & term(:z) + term(:x))
188212
sp8 = didspec(TestDID, @formula(y ~ treat(g, ttreat(t, 0), tpara(0)) & z + x))
189213
@test sp7 == sp6
190-
@test sp8 == sp7
191-
@test sp6 == @didspec TestDID term(:y) ~ testterm & term(:z) + term(:x)
192-
@test sp6 == @didspec TestDID @formula(y ~ treat(g, ttreat(t, 0), tpara(0)) & z + x)
214+
@test sp8 === sp7
215+
@test sp6 === @didspec TestDID term(:y) ~ testterm & term(:z) + term(:x)
216+
@test sp6 === @didspec TestDID @formula(y ~ treat(g, ttreat(t, 0), tpara(0)) & z + x)
193217
end
194218

195219
@testset "@did" begin

0 commit comments

Comments
 (0)