Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit cc83160

Browse files
committed
Add a type parameter for AbstractStatsProcedure
1 parent 5186049 commit cc83160

File tree

5 files changed

+57
-16
lines changed

5 files changed

+57
-16
lines changed

src/DiffinDiffsBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using StatsBase
66
@reexport using StatsModels
77

88
import Base: ==, show
9+
import Base: eltype, getindex, iterate, length
910

1011
export @fieldequal,
1112
eachterm,

src/did.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""
2-
DiffinDiffsEstimator <: AbstractStatsProcedure
2+
DiffinDiffsEstimator{T} <: AbstractStatsProcedure{T}
33
4-
Supertype for all types specifying the estimation procedure for difference-in-differences.
4+
Specify the estimation procedure for difference-in-differences.
55
"""
6-
abstract type DiffinDiffsEstimator <: AbstractStatsProcedure end
6+
struct DiffinDiffsEstimator{T} <: AbstractStatsProcedure{T} end
77

88
"""
99
DefaultDID <: DiffinDiffsEstimator
1010
1111
Default difference-in-differences estimator selected based on the context.
1212
"""
13-
struct DefaultDID <: DiffinDiffsEstimator end
13+
const DefaultDID = DiffinDiffsEstimator{Tuple{}}
14+
15+
show(io::IO, d::Type{DefaultDID}) = print(io, "DefaultDID")
1416

1517
did(tr::AbstractTreatment, pr::AbstractParallel; kwargs...) =
1618
did(DefaultDID, tr, pr; kwargs...)

src/procedures.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,50 @@
11
"""
2-
AbstractStatsProcedure
2+
AbstractStatsProcedure{T<:NTuple{N,Function} where N}
33
44
Supertype for all types specifying the procedure for statistical estimation or inference.
5+
6+
The procedure is determined by the `parameters` of `T`,
7+
which are types of a sequence of functions.
58
"""
6-
abstract type AbstractStatsProcedure end
9+
abstract type AbstractStatsProcedure{T<:NTuple{N,Function} where N} end
10+
11+
length(p::AbstractStatsProcedure{T}) where T = length(T.parameters)
12+
eltype(::Type{<:AbstractStatsProcedure}) = Function
13+
14+
function getindex(p::AbstractStatsProcedure{T}, i) where T
15+
fs = T.parameters[i]
16+
return fs isa Type && fs <: Function ? fs.instance : [f.instance for f in fs]
17+
end
18+
19+
iterate(p::AbstractStatsProcedure{T}, state=1) where T =
20+
state > length(p) ? nothing : (p[state], state+1)
721

822
"""
9-
StatsSpec{T<:AbstractStatsProcedure}
23+
StatsSpec{T<:AbstractStatsProcedure, IsComplete}
1024
11-
Record the specification for a statistical procedure and
12-
optionally a name for the specification.
25+
Record the specification for a statistical procedure of type `T`
26+
that may or may not be verified to be complete as indicated by `IsComplete`.
1327
1428
The specification is recorded based on the arguments
1529
for a function that will conduct the procedure.
1630
It is assumed that a tuple of positional arguments accepted by the function
1731
can be constructed solely based on the type of each argument.
1832
1933
# Fields
20-
- `name::String`: name for the specification.
34+
- `name::String`: an optional name for the specification.
2135
- `args::Dict{Symbol}`: positional arguments indexed based on their types.
2236
- `kwargs::Dict{Symbol}`: keyword arguments.
2337
"""
24-
struct StatsSpec{T<:AbstractStatsProcedure}
38+
struct StatsSpec{T<:AbstractStatsProcedure, IsComplete}
2539
name::String
2640
args::Dict{Symbol}
2741
kwargs::Dict{Symbol}
28-
StatsSpec(T::Type{<:AbstractStatsProcedure},
29-
name::String, args::Dict{Symbol}, kwargs::Dict{Symbol}) =
30-
new{T}(name, args, kwargs)
3142
end
3243

44+
StatsSpec(T::Type{<:AbstractStatsProcedure}, name::String,
45+
args::Dict{Symbol}, kwargs::Dict{Symbol}, IsComplete::Bool=false) =
46+
StatsSpec{T,IsComplete}(name, args, kwargs)
47+
3348
==(a::StatsSpec{T}, b::StatsSpec{T}) where {T<:AbstractStatsProcedure} =
3449
a.args == b.args && a.kwargs == b.kwargs
3550

test/did.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,24 @@ did(::Type{TestDID}, ::AbstractTreatment, ::AbstractParallel;
77
yterm=:unknown, treatname=:unknown, treatintterms=nothing, xterms=nothing) =
88
(yterm, treatname, treatintterms, xterms)
99

10+
@testset "DiffinDiffsEstimator" begin
11+
d = DefaultDID()
12+
@test length(d) == 0
13+
@test eltype(d) == Function
14+
@test_throws BoundsError d[1]
15+
@test collect(d) == Function[]
16+
17+
d = TestDID()
18+
@test length(d) == 2
19+
@test eltype(d) == Function
20+
@test d[1] == print
21+
@test d[1:2] == [print, println]
22+
@test d[[2,1]] == [println, print]
23+
@test_throws BoundsError d[3]
24+
@test_throws MethodError d[:a]
25+
@test collect(d) == Function[print, println]
26+
end
27+
1028
@testset "did wrapper" begin
1129
@testset "DefaultDID" begin
1230
@test_throws ErrorException did(TR, PR)

test/testutils.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Define simple generic types and methods for testing
22

3+
import Base: show
4+
35
struct TestTreatment <: AbstractTreatment
46
time::Symbol
57
ref::Int
@@ -14,8 +16,11 @@ end
1416
TestParallel(e::Int) = TestParallel{ParallelCondition,ParallelStrength}(e)
1517
tpara(c::ConstantTerm) = TestParallel{ParallelCondition,ParallelStrength}(c.n)
1618

17-
struct NotImplemented <: DiffinDiffsEstimator end
18-
struct TestDID <: DiffinDiffsEstimator end
19+
const NotImplemented = DiffinDiffsEstimator{Tuple{typeof(println)}}
20+
show(io::IO, d::Type{NotImplemented}) = print(io, "NotImplemented")
21+
22+
const TestDID = DiffinDiffsEstimator{Tuple{typeof(print), typeof(println)}}
23+
show(io::IO, d::Type{TestDID}) = print(io, "TestDID")
1924

2025
const TR = TestTreatment(:t, 0)
2126
const PR = TestParallel(0)

0 commit comments

Comments
 (0)