Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 52f9b16

Browse files
authored
Add RotatingTimeArray and improve settime (#30)
1 parent a2bce2b commit 52f9b16

File tree

9 files changed

+231
-167
lines changed

9 files changed

+231
-167
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1818
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1919
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
2020
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
21+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2122
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2223

2324
[compat]
@@ -34,6 +35,7 @@ Reexport = "0.2, 1"
3435
StatsBase = "0.33"
3536
StatsFuns = "0.9"
3637
StatsModels = "0.6.18"
38+
StructArrays = "0.5"
3739
Tables = "1.2"
3840
julia = "1.3"
3941

src/DiffinDiffsBase.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using StatsBase: CoefTable, Weights, stderror, uweights
1717
using StatsFuns: tdistccdf, tdistinvcdf
1818
@reexport using StatsModels
1919
using StatsModels: Schema
20+
using StructArrays: StructArray
2021
using Tables
2122
using Tables: AbstractColumns, istable, columnnames, getcolumn
2223

@@ -37,7 +38,7 @@ export cb,
3738

3839
RotatingTimeValue,
3940
rotatingtime,
40-
RotatingRange,
41+
RotatingTimeArray,
4142

4243
VecColumnTable,
4344
VecColsRow,

src/ScaledArrays.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ end
88
"""
99
ScaledArray{T,R,N,RA,P} <: AbstractArray{T,N}
1010
11-
An array type that stores data as indices of a range.
11+
Array type that stores data as indices of a range.
1212
1313
# Fields
1414
- `refs::RA<:AbstractArray{R,N}`: an array of indices.
@@ -189,10 +189,7 @@ ScaledArray(sa::ScaledArray, step=nothing; reftype::Type=eltype(refarray(sa)),
189189
start=nothing, stop=nothing, xtype::Type=eltype(sa), usepool::Bool=true) =
190190
ScaledArray(sa, reftype, xtype, start, step, stop, usepool)
191191

192-
Base.similar(sa::ScaledArray{T,R}, dims::Dims=size(sa)) where {T,R} =
193-
ScaledArray(RefArray(ones(R, dims)), DataAPI.refpool(sa), Dict{T,R}())
194-
195-
Base.similar(sa::SubArray{<:Any, <:Any, <:ScaledArray{T,R}}, dims::Dims=size(sa)) where {T,R} =
192+
Base.similar(sa::ScaledArrOrSub{T,R}, dims::Dims=size(sa)) where {T,R} =
196193
ScaledArray(RefArray(ones(R, dims)), DataAPI.refpool(sa), Dict{T,R}())
197194

198195
Base.similar(sa::ScaledArrOrSub, dims::Int...) = similar(sa, dims)

src/operations.jl

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111
function _refs_pool(col::AbstractArray, reftype::Type{<:Integer}=UInt32)
1212
refs = refarray(col)
1313
pool = refpool(col)
14-
labeled = pool !== nothing && !(pool isa RotatingTimeRange)
14+
labeled = pool !== nothing
1515
if !labeled
1616
refs, invpool, pool = _label(col, eltype(col), reftype)
1717
end
@@ -97,7 +97,7 @@ function cellrows(cols::VecColumnTable, refrows::IdDict)
9797
columns = Vector{AbstractVector}(undef, ncol)
9898
for i in 1:ncol
9999
c = cols[i]
100-
if typeof(c) <: ScaledArrOrSub
100+
if typeof(c) <: Union{ScaledArrOrSub, RotatingTimeArray}
101101
columns[i] = similar(c, ncell)
102102
else
103103
columns[i] = Vector{eltype(c)}(undef, ncell)
@@ -127,76 +127,66 @@ function cellrows(cols::VecColumnTable, refrows::IdDict)
127127
end
128128

129129
"""
130-
settime(data, timename; step, start, stop, reftype, rotation)
131-
settime(time::AbstractArray; step, start, stop, reftype, rotation)
130+
settime(time::AbstractArray, step; start, stop, reftype, rotation)
132131
133132
Convert a column of time values to a [`ScaledArray`](@ref)
134133
for representing discretized time periods of uniform length.
135-
Time values can be provided either as a table containing the relevant column or as an array.
134+
If `rotation` is specified (time values belong to multiple rotation groups),
135+
a [`RotatingTimeArray`](@ref) is returned with the `time` field
136+
being a [`ScaledArray`](@ref).
136137
The returned array ensures well-defined time intervals for operations involving relative time
137138
(such as [`lag`](@ref) and [`diff`](@ref)).
138139
See also [`aligntime`](@ref).
139140
140141
# Arguments
141-
- `data`: a Tables.jl-compatible data table.
142-
- `timename::Union{Symbol,Integer}`: the name of the column in `data` that contains time values.
143-
- `time::AbstractArray`: the array containing time values (only needed for the alternative method).
142+
- `time::AbstractArray`: the array containing time values.
143+
- `step=nothing`: the length of each time interval; try `step=one(eltype(time))` if not specified.
144144
145145
# Keywords
146-
- `step=nothing`: the length of each time interval; try step=1 if not specified.
147146
- `start=nothing`: the first element of the `pool` of the returned [`ScaledArray`](@ref).
148147
- `stop=nothing`: the last element of the `pool` of the returned [`ScaledArray`](@ref).
149-
- `reftype::Type{<:Signed}=Int32`: the element type of the reference values for the returned [`ScaledArray`](@ref).
150-
- `rotation=nothing`: rotation groups in a rotating sampling design; use [`RotatingTimeValue`](@ref)s as reference values.
148+
- `reftype::Type{<:Signed}=Int32`: the element type of the reference values for the [`ScaledArray`](@ref).
149+
- `rotation=nothing`: rotation groups in a rotating sampling design.
151150
"""
152-
function settime(time::AbstractArray; step=nothing, start=nothing, stop=nothing,
151+
function settime(time::AbstractArray, step=nothing; start=nothing, stop=nothing,
153152
reftype::Type{<:Signed}=Int32, rotation=nothing)
154153
T = eltype(time)
155154
T <: ValidTimeType && !(T <: RotatingTimeValue) ||
156155
throw(ArgumentError("unaccepted element type $T from time column"))
157156
step === nothing && (step = one(T))
158157
time = ScaledArray(time, start, step, stop; reftype=reftype)
159-
if rotation !== nothing
160-
refs = rotatingtime(rotation, time.refs)
161-
rots = unique(rotation)
162-
invpool = Dict{RotatingTimeValue{eltype(rotation), T}, eltype(refs)}()
163-
for (k, v) in time.invpool
164-
for r in rots
165-
rt = RotatingTimeValue(r, k)
166-
invpool[rt] = RotatingTimeValue(r, v)
167-
end
168-
end
169-
rmin, rmax = extrema(rots)
170-
pool = RotatingTimeValue(rmin, first(time.pool)):scale(time):RotatingTimeValue(rmax, last(time.pool))
171-
time = ScaledArray(RefArray(refs), pool, invpool)
172-
end
158+
rotation === nothing || (time = RotatingTimeArray(rotation, time))
173159
return time
174160
end
175161

176-
function settime(data, timename::Union{Symbol,Integer};
177-
step=nothing, start=nothing, stop=nothing,
178-
reftype::Type{<:Signed}=Int32, rotation=nothing)
179-
checktable(data)
180-
return settime(getcolumn(data, timename);
181-
step=step, start=start, stop=stop, reftype=reftype, rotation=rotation)
182-
end
183-
184162
"""
163+
aligntime(col::AbstractArray, time::ScaledArrOrSub)
164+
aligntime(col::AbstractArray, time::RotatingTimeArray)
185165
aligntime(data, colname::Union{Symbol,Integer}, timename::Union{Symbol,Integer})
186166
187-
Convert a column of time values indexed by `colname` from `data` table
188-
to a [`ScaledArray`](@ref) with a `pool`
167+
Convert a column of time values `col` to a [`ScaledArray`](@ref) with a `pool`
189168
that has the same first element and step size as the `pool` from
190-
the [`ScaledArray`](@ref) indexed by `timename`.
169+
the [`ScaledArray`](@ref) `time`.
170+
If `time` is a [`RotatingTimeArray`](@ref) with the `time` field being a [`ScaledArray`](@ref),
171+
the returned array is also a [`RotatingTimeArray`](@ref)
172+
with the `time` field being the converted [`ScaledArray`](@ref).
173+
Alternative, the arrays may be specified with a Tables.jl-compatible `data` table
174+
and column indices `colname` and `timename`.
191175
See also [`settime`](@ref).
192176
193177
This is useful for representing all discretized time periods with the same scale
194178
so that the underlying reference values returned by `DataAPI.refarray`
195179
can be directly comparable across the columns.
196180
"""
181+
aligntime(col::AbstractArray, time::ScaledArrOrSub) = align(col, time)
182+
aligntime(col::AbstractArray, time::RotatingTimeArray) =
183+
RotatingTimeArray(time.rotation, align(col, time.time))
184+
aligntime(col::RotatingTimeArray, time::RotatingTimeArray) =
185+
RotatingTimeArray(time.rotation, align(col.time, time.time))
186+
197187
function aligntime(data, colname::Union{Symbol,Integer}, timename::Union{Symbol,Integer})
198188
checktable(data)
199-
return align(getcolumn(data, colname), getcolumn(data, timename))
189+
return aligntime(getcolumn(data, colname), getcolumn(data, timename))
200190
end
201191

202192
"""
@@ -246,7 +236,7 @@ that is returned by [`settime`](@ref).
246236
- `time::AbstractArray`: the array containing time values (only needed for the alternative method).
247237
248238
# Keywords
249-
- `step=nothing`: the length of each time interval; try step=1 if not specified.
239+
- `step=nothing`: the length of each time interval; try `step=one(eltype(time))` if not specified.
250240
- `reftype::Type{<:Signed}=Int32`: the element type of the reference values for [`PanelStructure`](@ref).
251241
- `rotation=nothing`: rotation groups in a rotating sampling design; use [`RotatingTimeValue`](@ref)s as reference values.
252242
@@ -263,7 +253,7 @@ function setpanel(id::AbstractArray, time::AbstractArray; step=nothing,
263253
"id has length $(length(id)) while time has length $(length(time))"))
264254
refs, idpool, labeled = _refs_pool(id)
265255
labeled && (refs = copy(refs); idpool = copy(idpool))
266-
time = settime(time; step=step, reftype=reftype, rotation=rotation)
256+
time = settime(time, step; reftype=reftype, rotation=rotation)
267257
trefs = refarray(time)
268258
tpool = refpool(time)
269259
# Multiply 2 to create enough gaps between id groups for the largest possible lead/lag

src/procedures.jl

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,19 @@ const GroupTerms = StatsStep{:GroupTerms, typeof(groupterms), false}
6767

6868
required(::GroupTerms) = (:treatintterms, :xterms)
6969

70+
function _checkscales(col1::AbstractArray, col2::AbstractArray, treatvars::Vector{Symbol})
71+
if col1 isa ScaledArrOrSub || col2 isa ScaledArrOrSub
72+
col1 isa ScaledArrOrSub && col2 isa ScaledArrOrSub ||
73+
throw(ArgumentError("time fields in both columns $(treatvars[1]) and $(treatvars[2]) must be ScaledArrays; see settime and aligntime"))
74+
first(DataAPI.refpool(col1)) == first(DataAPI.refpool(col2)) &&
75+
scale(col1) == scale(col2) || throw(ArgumentError(
76+
"time fields in columns $(treatvars[1]) and $(treatvars[2]) are not aligned; see aligntime"))
77+
else
78+
eltype(col1) <: Integer || throw(ArgumentError(
79+
"columns $(treatvars[1]) and $(treatvars[2]) must be stored in ScaledArrays; see settime and aligntime"))
80+
end
81+
end
82+
7083
function checktreatvars(::DynamicTreatment{SharpDesign}, pr::TrendParallel{Unconditional},
7184
treatvars::Vector{Symbol}, data)
7285
# treatvars should be cohort and time variables
@@ -76,22 +89,22 @@ function checktreatvars(::DynamicTreatment{SharpDesign}, pr::TrendParallel{Uncon
7689
T2 = nonmissingtype(eltype(col2))
7790
T1 == T2 || throw(ArgumentError(
7891
"nonmissing elements from columns $(treatvars[1]) and $(treatvars[2]) have different types $T1 and $T2"))
79-
T1 <: Union{Integer, RotatingTimeValue{<:Any, <:Integer}} ||
80-
col1 isa ScaledArrOrSub && col2 isa ScaledArrOrSub ||
81-
throw(ArgumentError("columns $(treatvars[1]) and $(treatvars[2]) must either have integer elements or be ScaledArrays; see settime and aligntime"))
8292
T1 <: ValidTimeType ||
8393
throw(ArgumentError("column $(treatvars[1]) has unaccepted element type $(T1)"))
8494
eltype(pr.e) == T1 || throw(ArgumentError("element type $(eltype(pr.e)) of control cohorts from $pr does not match element type $T1 from data; expect $T1"))
85-
if col1 isa ScaledArrOrSub
86-
first(DataAPI.refpool(col1)) == first(DataAPI.refpool(col2)) &&
87-
scale(col1) == scale(col2) || throw(ArgumentError(
88-
"time values in columns $(treatvars[1]) and $(treatvars[2]) are not aligned; see aligntime"))
95+
if T1 <: RotatingTimeValue
96+
col1 isa RotatingTimeArray && col2 isa RotatingTimeArray ||
97+
throw(ArgumentError("columns $(treatvars[1]) and $(treatvars[2]) must be RotatingTimeArrays; see settime"))
98+
_checkscales(col1.time, col2.time, treatvars)
99+
else
100+
_checkscales(col1, col2, treatvars)
89101
end
90102
end
91103

92104
function _overlaptime(tr::DynamicTreatment, tr_rows::BitVector, data)
93-
control_time = Set(view(refarray(getcolumn(data, tr.time)), .!tr_rows))
94-
treated_time = Set(view(refarray(getcolumn(data, tr.time)), tr_rows))
105+
timeref = refarray(getcolumn(data, tr.time))
106+
control_time = Set(view(timeref, .!tr_rows))
107+
treated_time = Set(view(timeref, tr_rows))
95108
return intersect(control_time, treated_time), control_time, treated_time
96109
end
97110

@@ -108,20 +121,26 @@ end
108121
function overlap!(esample::BitVector, tr_rows::BitVector, aux::BitVector, tr::DynamicTreatment,
109122
pr::NotYetTreatedParallel{Unconditional}, treatname::Symbol, data)
110123
overlap_time, _c, _t = _overlaptime(tr, tr_rows, data)
111-
timetype = eltype(overlap_time)
112-
invpool = invrefpool(getcolumn(data, tr.time))
113-
e = invpool === nothing ? Set(pr.e) : Set(invpool[c] for c in pr.e)
114-
if !(timetype <: RotatingTimeValue)
124+
timecol = getcolumn(data, tr.time)
125+
if !(eltype(timecol) <: RotatingTimeValue)
126+
invpool = invrefpool(timecol)
127+
e = invpool === nothing ? Set(pr.e) : Set(invpool[c] for c in pr.e)
115128
ecut = invpool === nothing ? pr.ecut[1] : invpool[pr.ecut[1]]
116129
filter!(x -> x < ecut, overlap_time)
117130
isvalidcohort = x -> x < ecut || x in e
118131
else
119-
ecut = invpool === nothing ? pr.ecut : (invpool[e] for e in pr.ecut)
120-
ecut = IdDict(e.rotation=>e.time for e in ecut)
132+
invpool = invrefpool(timecol.time)
133+
if invpool === nothing
134+
e = Set(pr.e)
135+
ecut = IdDict(e.rotation=>e.time for e in pr.ecut)
136+
else
137+
e = Set(RotatingTimeValue(c.rotation, invpool[c.time]) for c in pr.e)
138+
ecut = IdDict(e.rotation=>invpool[e.time] for e in pr.ecut)
139+
end
121140
filter!(x -> x.time < ecut[x.rotation], overlap_time)
122141
isvalidcohort = x -> x.time < ecut[x.rotation] || x in e
123142
end
124-
aux[esample] .= view(refarray(getcolumn(data, tr.time)), esample) .∈ (overlap_time,)
143+
aux[esample] .= view(refarray(timecol), esample) .∈ (overlap_time,)
125144
esample[esample] .&= view(aux, esample)
126145
aux[esample] .= isvalidcohort.(view(refarray(getcolumn(data, treatname)), esample))
127146
esample[esample] .&= view(aux, esample)

0 commit comments

Comments
 (0)