Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit 9803fa1

Browse files
committed
Improve specset and did
1 parent b276c26 commit 9803fa1

File tree

8 files changed

+333
-215
lines changed

8 files changed

+333
-215
lines changed

src/DiffinDiffsBase.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,11 @@ using Tables: istable, getcolumn
1111

1212
import Base: ==, show, union
1313
import Base: eltype, firstindex, lastindex, getindex, iterate, length, sym_in
14-
1514
import StatsModels: termvars
1615

1716
export TupleTerm
1817

19-
export @fieldequal,
20-
eachterm,
21-
cb,
22-
@unpack,
18+
export cb,
2319
,
2420
exampledata,
2521

@@ -48,8 +44,6 @@ export @fieldequal,
4844

4945
TreatmentTerm,
5046
treat,
51-
hastreat,
52-
parse_treat,
5347

5448
StatsStep,
5549
namedargs,
@@ -58,8 +52,8 @@ export @fieldequal,
5852
PooledStatsProcedure,
5953
pool,
6054
StatsSpec,
61-
@specset,
6255
proceed,
56+
@specset,
6357

6458
CheckData,
6559
CheckVars,
@@ -68,11 +62,8 @@ export @fieldequal,
6862
DefaultDID,
6963
did,
7064
didspec,
71-
@didspec,
7265
@did,
73-
DIDResult,
74-
agg,
75-
AggregatedDIDResult
66+
DIDResult
7667

7768
include("utils.jl")
7869
include("treatments.jl")

src/StatsProcedures.jl

Lines changed: 169 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,11 @@ _combinedargs(::StatsStep, ::Any) = ()
4444
function (step::StatsStep{A,F})(ntargs::NamedTuple; verbose::Bool=false) where {A,F}
4545
haskey(ntargs, :verbose) && (verbose = ntargs.verbose)
4646
verbose && printstyled("Running ", step, "\n", color=:green)
47-
ret, share = F.instance(_getargs(ntargs, step)..., _combinedargs(step, (ntargs,))...)
48-
if ret isa NamedTuple
49-
return merge(ntargs, ret)
50-
elseif ret === nothing
51-
return ntargs
47+
ret = F.instance(_getargs(ntargs, step)..., _combinedargs(step, (ntargs,))...)
48+
if ret isa Tuple{<:NamedTuple, Bool}
49+
return merge(ntargs, ret[1])
5250
else
53-
error("unexpected returned object from function associated with StatsStep")
51+
error("unexpected $(typeof(ret)) returned from $(F.name.mt.name) associated with StatsStep $A")
5452
end
5553
end
5654

@@ -79,6 +77,8 @@ all subtypes of `AbstractStatsProcedure`.
7977
"""
8078
abstract type AbstractStatsProcedure{Alias, T<:NTuple{N,StatsStep} where N} end
8179

80+
"""
    _result(::Type{<:AbstractStatsProcedure}, ntargs::NamedTuple)

Fallback method that hands back `ntargs` unchanged.
It is applied to the `NamedTuple` accumulated from running the steps of a procedure,
so without a more specific method the accumulated `NamedTuple` itself is the result.
"""
function _result(::Type{<:AbstractStatsProcedure}, ntargs::NamedTuple)
    return ntargs
end
81+
8282
length(::AbstractStatsProcedure{A,T}) where {A,T} = length(T.parameters)
8383
eltype(::Type{<:AbstractStatsProcedure}) = StatsStep
8484
firstindex(::AbstractStatsProcedure{A,T}) where {A,T} = firstindex(T.parameters)
@@ -292,8 +292,8 @@ An optional name for the specification can be attached as parameter `Alias`.
292292
(sp::StatsSpec{A,T})(; verbose::Bool=false, keep=nothing, keepall::Bool=false)
293293
294294
Execute the procedure of type `T` with the arguments specified in `args`.
295-
By default, only an object with a key `result` assigned by a [`StatsStep`](@ref)
296-
or the last value returned by the last [`StatsStep`](@ref) is returned.
295+
By default, a dedicated result object for `T` is returned if it is available.
296+
Otherwise, the last value returned by the last [`StatsStep`](@ref) is returned.
297297
298298
## Keywords
299299
- `verbose::Bool=false`: print the name of each step when it is called.
@@ -331,8 +331,10 @@ _procedure(::StatsSpec{A,T}) where {A,T} = T
331331

332332
function (sp::StatsSpec{A,T})(;
333333
verbose::Bool=false, keep=nothing, keepall::Bool=false) where {A,T}
334-
args = verbose ? merge(sp.args, (verbose=true,)) : sp.args
334+
args = deepcopy(sp.args)
335+
args = verbose ? merge(args, (verbose=true,)) : args
335336
ntall = foldl(|>, T(), init=args)
337+
ntall = _result(T, ntall)
336338
if keepall
337339
return ntall
338340
elseif !isempty(ntall)
@@ -364,77 +366,6 @@ function show(io::IO, ::MIME"text/plain", sp::StatsSpec{A,T}) where {A,T}
364366
_show_args(io, sp)
365367
end
366368

367-
function _spec_walker(x, parsers, formatters, ntargs_set)
368-
@capture(x, StatsSpec(formatter_(parser_(rawargs__))...)(;o__)) || return x
369-
push!(parsers, parser)
370-
push!(formatters, formatter)
371-
length(o) > 0 &&
372-
@warn "[options] specified for individual StatsSpec are ignored inside @specset"
373-
return :(push!($ntargs_set, $parser($(rawargs...))))
374-
end
375-
376-
"""
377-
@specset default_args... begin ... end
378-
@specset default_args... for v in (...) ... end
379-
@specset default_args... for v in (...), w in (...) ... end
380-
381-
Return a vector of [`StatsSpec`](@ref) with shared default values for arguments.
382-
See also [`proceed`](@ref).
383-
384-
# Arguments
385-
- `default_args...`: optional default values for arguments shared by all [`StatsSpec`](@ref)s.
386-
- `code block`: a `begin/end` block or a `for` loop containing arguments for constructing [`StatsSpec`](@ref)s.
387-
388-
# Notes
389-
`@specset` transforms `Expr`s that construct [`StatsSpec`](@ref)
390-
to collect the sets of arguments from the code block
391-
and infers how the arguments entered by users need to be processed
392-
based on the names of functions called within [`StatsSpec`](@ref).
393-
For end users, `Macro`s that generate `Expr`s for these function calls should be provided.
394-
395-
Optional default arguments are merged
396-
with the arguments provided for each individual specification
397-
and supersede the default values specified for each procedure through [`namedargs`](@ref).
398-
These default arguments should be specified in the same pattern as
399-
how arguments are specified for each specification inside the code block,
400-
as `@specset` processes these arguments by calling
401-
the same functions found in the code block.
402-
"""
403-
macro specset(args...)
404-
nargs = length(args)
405-
nargs == 0 && throw(ArgumentError("no argument is found for @specset"))
406-
default_args = nargs > 1 ? _args_kwargs(args[1:end-1]) : nothing
407-
specs = args[end]
408-
isexpr(specs, :block, :for) ||
409-
throw(ArgumentError("last argument to @specset must be begin/end block or for loop"))
410-
411-
parsers, formatters, ntargs_set = Symbol[], Symbol[], NamedTuple[]
412-
blk = postwalk(x->_spec_walker(x, parsers, formatters, ntargs_set), specs)
413-
nparser = length(unique!(parsers))
414-
nparser == 1 ||
415-
throw(ArgumentError("found $nparser parsers from arguments while expecting one"))
416-
nformatter = length(unique!(formatters))
417-
nformatter == 1 ||
418-
throw(ArgumentError("found $nformatter formatters from arguments while expecting one"))
419-
420-
parser, formatter = parsers[1], formatters[1]
421-
if default_args === nothing
422-
defaults = :(NamedTuple())
423-
else
424-
defaults = esc(:($parser($(default_args[1]...); $(default_args[2]...))))
425-
end
426-
427-
return quote
428-
$(esc(blk))
429-
local nspec = length($ntargs_set)
430-
local sps = Vector{StatsSpec}(undef, nspec)
431-
for i in 1:nspec
432-
sps[i] = StatsSpec($(esc(formatter))(merge($defaults, $ntargs_set[i]))...)
433-
end
434-
sps
435-
end
436-
end
437-
438369
"""
439370
proceed(sps::AbstractVector{<:StatsSpec}; kwargs...)
440371
@@ -448,21 +379,21 @@ See also [`@specset`](@ref).
448379
- `keepall::Bool=false`: return all objects generated by procedures along with arguments from the [`StatsSpec`](@ref)s.
449380
450381
# Returns
451-
- `Vector{NamedTuple}`: results for each specification in the same order of `sps`.
382+
- `Vector`: results for each specification in the same order as `sps`.
452383
453-
By default, either the object with a key `result`
384+
By default, either a dedicated result object for the corresponding procedure
454385
or the last value returned by the last [`StatsStep`](@ref)
455-
is contained in each `NamedTuple`.
386+
becomes an element in the returned `Vector` for each [`StatsSpec`](@ref).
456387
When either `keep` or `keepall` is specified,
457-
additional objects are included.
388+
a `NamedTuple` with additional objects is formed for each [`StatsSpec`](@ref).
458389
"""
459390
function proceed(sps::AbstractVector{<:StatsSpec};
460391
verbose::Bool=false, keep=nothing, keepall::Bool=false)
461392
nsps = length(sps)
462393
nsps == 0 && throw(ArgumentError("expect a nonempty vector"))
463394
traces = Vector{NamedTuple}(undef, nsps)
464395
for i in 1:nsps
465-
traces[i] = sps[i].args
396+
traces[i] = deepcopy(sps[i].args)
466397
end
467398
gids = groupfind(r->_procedure(r)(), sps)
468399
steps = pool((p for p in keys(gids))...)
@@ -473,18 +404,23 @@ function proceed(sps::AbstractVector{<:StatsSpec};
473404
taskids = vcat((gids[steps.procs[i]] for i in _sharedby(step))...)
474405
tasks = groupview(r->_getargs(r, step), view(traces, taskids))
475406
for (ins, subtb) in pairs(tasks)
476-
ret, share = _f(step)(ins..., _combinedargs(step, subtb)...)
407+
ret = _f(step)(ins..., _combinedargs(step, subtb)...)
408+
if ret isa Tuple{<:NamedTuple, Bool}
409+
ret, share = ret
410+
else
411+
fname = typeof(_f(step)).name.mt.name
412+
stepname = typeof(step).parameters[1].parameters[1]
413+
error("unexpected $(typeof(ret)) returned from $fname associated with StatsStep $stepname")
414+
end
477415
ntask += 1
478416
ntask_total += 1
479-
if ret !== nothing
480-
if share
481-
for i in eachindex(subtb)
482-
subtb[i] = merge(subtb[i], ret)
483-
end
484-
else
485-
for i in eachindex(subtb)
486-
subtb[i] = merge(subtb[i], deepcopy(ret))
487-
end
417+
if share
418+
for i in eachindex(subtb)
419+
subtb[i] = merge(subtb[i], ret)
420+
end
421+
else
422+
for i in eachindex(subtb)
423+
subtb[i] = merge(subtb[i], deepcopy(ret))
488424
end
489425
end
490426
end
@@ -496,6 +432,9 @@ function proceed(sps::AbstractVector{<:StatsSpec};
496432
verbose && printstyled("All steps finished (", ntask_total,
497433
ntask_total > 1 ? " tasks" : " task", " for ", nprocs,
498434
nprocs > 1 ? " procedures)\n" : " procedure)\n", bold=true, color=:green)
435+
for i in 1:nsps
436+
traces[i] = _result(_procedure(sps[i]), traces[i])
437+
end
499438
if keepall
500439
return traces
501440
elseif keep===nothing
@@ -516,3 +455,138 @@ function proceed(sps::AbstractVector{<:StatsSpec};
516455
return traces
517456
end
518457
end
458+
459+
"""
    _parse!(options::Expr, args)

Parse the option entries in `args` and return the value of the `noproceed` option.

Each element of `args` must be either a `Symbol` (shorthand for `name=true`)
or an assignment expression `name=val`.
Every option other than `noproceed` is appended to `options.args`
as a `name => val` pair expression.
The value of `noproceed` is not stored in `options`;
it is returned instead (`false` if the option is absent).

# Throws
- `ArgumentError`: if an element is neither a `Symbol` nor an assignment,
  or if `noproceed` is assigned anything other than a literal `true` or `false`.
"""
function _parse!(options::Expr, args)
    noproceed = false
    for arg in args
        if arg isa Symbol
            # Assume a symbol means the kwarg takes value true
            key, val = arg, true
        elseif isexpr(arg, :(=))
            key, val = arg.args[1], arg.args[2]
        else
            throw(ArgumentError("unexpected option $arg"))
        end
        if key == :noproceed
            # The returned value is branched on during macro expansion,
            # so anything other than a Bool literal cannot be handled
            val isa Bool || throw(ArgumentError(
                "expect a literal true or false for option noproceed; got $val"))
            noproceed = val
        else
            push!(options.args, Expr(:call, :(=>), Expr(:quote, key), val))
        end
    end
    return noproceed
end
483+
484+
function _spec_walker1(x, parsers, formatters, ntargs_set)
485+
@capture(x, StatsSpec(formatter_(parser_(rawargs__))...)(;o__)) || return x
486+
push!(parsers, parser)
487+
push!(formatters, formatter)
488+
return :(push!($ntargs_set, $parser($(rawargs...))))
489+
end
490+
491+
function _spec_walker2(x, parsers, formatters, ntargs_set)
492+
@capture(x, StatsSpec(formatter_(parser_(rawargs__))...)) || return x
493+
push!(parsers, parser)
494+
push!(formatters, formatter)
495+
return :(push!($ntargs_set, $parser($(rawargs...))))
496+
end
497+
498+
"""
499+
@specset [option option=val ...] default_args... begin ... end
500+
@specset [option option=val ...] default_args... for v in (...) ... end
501+
@specset [option option=val ...] default_args... for v in (...), w in (...) ... end
502+
503+
Construct a vector of [`StatsSpec`](@ref) with shared default values for arguments
504+
and then conduct the procedures by calling [`proceed`](@ref).
505+
506+
# Arguments
507+
- `[option option=val ...]`: optional settings for @specset including keyword arguments for [`proceed`](@ref).
508+
- `default_args...`: optional default values for arguments shared by all [`StatsSpec`](@ref)s.
509+
- `code block`: a `begin/end` block or a `for` loop containing arguments for constructing [`StatsSpec`](@ref)s.
510+
511+
# Notes
512+
`@specset` transforms `Expr`s that construct [`StatsSpec`](@ref)
513+
to collect the sets of arguments from the code block
514+
and infers how the arguments entered by users need to be processed
515+
based on the names of functions called within [`StatsSpec`](@ref).
516+
For end users, `Macro`s that generate `Expr`s for these function calls should be provided.
517+
518+
Optional default arguments are merged
519+
with the arguments provided for each individual specification
520+
and supersede the default values specified for each procedure through [`namedargs`](@ref).
521+
These default arguments should be specified in the same pattern as
522+
how arguments are specified for each specification inside the code block,
523+
as `@specset` processes these arguments by calling
524+
the same functions found in the code block.
525+
526+
Options for the behavior of `@specset` can be provided in a bracket `[...]`
527+
as the first argument with each option separated by white space.
528+
For options that take a Boolean value,
529+
specifying the name of the option is enough for setting the value to be true.
530+
531+
The following options are available for altering the behavior of `@specset`:
532+
- `noproceed::Bool=false`: do not call [`proceed`](@ref) and return the vector of [`StatsSpec`](@ref).
533+
- `verbose::Bool=false`: print the name of each step when it is called.
534+
- `keep=nothing`: names (of type `Symbol`) of additional objects to be returned.
535+
- `keepall::Bool=false`: return all objects generated by procedures along with arguments from the [`StatsSpec`](@ref)s.
536+
"""
537+
macro specset(args...)
538+
nargs = length(args)
539+
nargs == 0 && throw(ArgumentError("no argument is found for @specset"))
540+
options = :(Dict{Symbol, Any}())
541+
noproceed = false
542+
default_args = nothing
543+
if nargs > 1
544+
if isexpr(args[1], :vect, :hcat, :vcat)
545+
noproceed = _parse!(options, args[1].args)
546+
nargs > 2 && (default_args = _args_kwargs(args[2:end-1]))
547+
else
548+
default_args = _args_kwargs(args[1:end-1])
549+
end
550+
end
551+
specs = macroexpand(__module__, args[end])
552+
isexpr(specs, :block, :for) ||
553+
throw(ArgumentError("last argument to @specset must be begin/end block or for loop"))
554+
555+
parsers, formatters, ntargs_set = Symbol[], Symbol[], NamedTuple[]
556+
walked = postwalk(x->_spec_walker1(x, parsers, formatters, ntargs_set), specs)
557+
walked = postwalk(x->_spec_walker2(x, parsers, formatters, ntargs_set), walked)
558+
nparser = length(unique!(parsers))
559+
nparser == 1 ||
560+
throw(ArgumentError("found $nparser parsers from arguments while expecting one"))
561+
nformatter = length(unique!(formatters))
562+
nformatter == 1 ||
563+
throw(ArgumentError("found $nformatter formatters from arguments while expecting one"))
564+
565+
parser, formatter = parsers[1], formatters[1]
566+
if default_args === nothing
567+
defaults = :(NamedTuple())
568+
else
569+
defaults = esc(:($parser($(default_args[1]...); $(default_args[2]...))))
570+
end
571+
572+
blk = quote
573+
$(esc(walked))
574+
local nspec = length($ntargs_set)
575+
local sps = Vector{StatsSpec}(undef, nspec)
576+
for i in 1:nspec
577+
sps[i] = StatsSpec($(esc(formatter))(merge($defaults, $ntargs_set[i]))...)
578+
end
579+
end
580+
581+
if noproceed
582+
return quote
583+
$blk
584+
sps
585+
end
586+
else
587+
return quote
588+
$blk
589+
proceed(sps; $options...)
590+
end
591+
end
592+
end

0 commit comments

Comments
 (0)