Skip to content
This repository was archived by the owner on Mar 11, 2022. It is now read-only.

Commit c751a3a

Browse files
committed
Fix some issues of specset
1 parent 922b28b commit c751a3a

File tree

5 files changed

+154
-57
lines changed

5 files changed

+154
-57
lines changed

src/DiffinDiffsBase.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ using Reexport
77
using StatsBase
88
@reexport using StatsModels
99
using StatsModels: TupleTerm
10-
using SplitApplyCombine: groupfind
10+
using SplitApplyCombine: groupfind, groupview
1111
using Tables: columntable, istable, rows, columns, getcolumn
12-
using TypedTables: Table, getproperty, getproperties
12+
using TypedTables: Table
1313

1414
import Base: ==, show, union
1515
import Base: eltype, firstindex, lastindex, getindex, iterate, length

src/StatsProcedures.jl

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ while ignoring the orders.
322322
(x::StatsSpec{A1,T}, y::StatsSpec{A2,T}) where {A1,A2,T} =
323323
x.args y.args
324324

325+
_procedure(::StatsSpec{A,T}) where {A,T} = T
326+
325327
function (sp::StatsSpec{A,T})(;
326328
verbose::Bool=false, keep=nothing, keepall::Bool=false) where {A,T}
327329
args = verbose ? merge(sp.args, (verbose=true,)) : sp.args
@@ -357,37 +359,57 @@ end
357359

358360
function run_specset(sps::AbstractVector{<:StatsSpec};
359361
verbose::Bool=false, keep=nothing, keepall::Bool=false)
360-
traces = Vector{NamedTuple}(undef, length(sps))
362+
nsps = length(sps)
363+
nsps == 0 && throw(ArgumentError("expect a nonempty vector"))
364+
traces = Vector{NamedTuple}(undef, nsps)
361365
fill!(traces, NamedTuple())
362366
tb = Table(spec=sps, trace=traces)
363-
gids = groupfind(r->procedure(r.spec), tb)
364-
steps = pool((p() for p in keys(gids))...)
367+
gids = groupfind(r->_procedure(r.spec)(), tb)
368+
steps = pool((p for p in keys(gids))...)
369+
ntask_total = 0
365370
for step in steps
366-
verbose && printstyled("Running ", step, "\n", color=:green)
367-
args = _specnames(step)
368-
tras = _tracenames(step)
369-
byf = r->merge(NamedTuple{args}(r.spec.args), NamedTuple{tras}(r.trace))
371+
ntask = 0
372+
verbose && printstyled("Running ", step, "...")
373+
specn = _specnames(step)
374+
tracn = _tracenames(step)
375+
byf = r->merge(NamedTuple{specn}(r.spec.args), NamedTuple{tracn}(r.trace))
370376
taskids = vcat((gids[steps.procs[i]] for i in _sharedby(step))...)
371377
tasks = groupview(byf, view(tb, taskids))
372378
for (ins, subtb) in pairs(tasks)
373379
ret = _f(step)(ins...)
374-
for tr in subtb.trace
375-
tr = merge(tr, deepcopy(ret))
380+
ntask += 1
381+
ntask_total += 1
382+
if ret !== nothing
383+
for i in eachindex(subtb.trace)
384+
subtb.trace[i] = merge(subtb.trace[i], deepcopy(ret))
385+
end
376386
end
377387
end
388+
nprocs = length(_sharedby(step))
389+
verbose && printstyled("Finished ", ntask, ntask > 1 ? " tasks" : " task", " for ",
390+
nprocs, nprocs > 1 ? " procedures\n" : " procedure\n")
378391
end
392+
nprocs = length(steps.procs)
393+
verbose && printstyled("All steps finished (", ntask_total,
394+
ntask_total > 1 ? " tasks" : " task", " for ", nprocs,
395+
nprocs > 1 ? " procedures)\n" : " procedure)\n", bold=true, color=:green)
379396
if keepall
380397
return tb
381398
elseif keep===nothing
382399
return [r[end] for r in tb.trace]
383400
else
384-
return [(;result=r.trace[end], (kv for kv in pairs(r.trace) if kv[1] in keep)...,
385-
(kv for kv in pairs(r.spec.args) if kv[1] in keep)...) for r in tb]
401+
if keep isa Symbol
402+
keep = [keep]
403+
elseif eltype(keep) != Symbol
404+
throw(ArgumentError("expect Symbol or collections of Symbols for the value of option `keep`"))
405+
end
406+
return [(; (kv for kv in pairs(r.spec.args) if kv[1] in keep)...,
407+
(kv for (i,kv) in enumerate(pairs(r.trace)) if
408+
kv[1] in keep || i==length(r.trace))...,) for r in tb]
386409
end
387410
end
388411

389-
function parse_specset_options(args)
390-
options = :(Dict{Symbol, Any}())
412+
function _parse_kwargs!(options::Expr, args)
391413
for arg in args
392414
# Assume a symbol means the kwarg takes value true
393415
if isa(arg, Symbol)
@@ -397,19 +419,19 @@ function parse_specset_options(args)
397419
key = Expr(:quote, arg.args[1])
398420
push!(options.args, Expr(:call, :(=>), key, arg.args[2]))
399421
else
400-
throw(ArgumentError("unexpected argument $arg to @specset"))
422+
throw(ArgumentError("unexpected argument $arg"))
401423
end
402424
end
403-
return options
404425
end
405426

406-
function spec_walker(x, parsers, formatters)
407-
@capture(x, StatsSpec(formatter_(parser_(rawargs__)))(;o__)) || return x
427+
function _spec_walker(x, parsers, formatters, ntargs_set)
428+
@capture(x, StatsSpec(formatter_(parser_(rawargs__))...)(;o__)) || return x
408429
push!(parsers, parser)
409430
push!(formatters, formatter)
431+
push!(ntargs_set, esc(:($parser($(rawargs...)))))
410432
length(o) > 0 &&
411433
@warn "[options] specified for individual StatsSpec are ignored inside @specset"
412-
return :(push!(ntargs_set, $parser($(rawargs...))))
434+
return :()
413435
end
414436

415437
"""
@@ -460,39 +482,35 @@ macro specset(args...)
460482
nargs = length(args)
461483
nargs == 0 && throw(ArgumentError("no argument is found for @specset"))
462484

485+
options = :(Dict{Symbol, Any}())
463486
if nargs > 1
464487
if isexpr(args[1], :vect, :hcat, :vcat)
465-
options = parse_specset_options(args[1].args)
466-
nargs > 2 && (default_args = args[2:end-1])
488+
_parse_kwargs!(options, args[1].args)
489+
nargs > 2 && (default_args = _args_kwargs(args[2:end-1]))
467490
else
468-
default_args = args[1:end-1]
491+
default_args = _args_kwargs(args[1:end-1])
469492
end
470493
else
471-
options = :((;))
472494
default_args = nothing
473495
end
474-
475496
specs = args[end]
476497
isexpr(specs, :block, :for) ||
477498
throw(ArgumentError("last argument to @specset must be begin/end block or for loop"))
478499

479-
parsers = []
480-
formatters = []
481-
blk = postwalk(x->spec_walker(x, parsers, formatters), specs)
482-
length(parsers)==1 && length(formatters)==1 ||
500+
parsers, formatters, ntargs_set = [], [], []
501+
postwalk(x->_spec_walker(x, parsers, formatters, ntargs_set), specs)
502+
length(unique!(parsers))==1 && length(unique!(formatters))==1 ||
483503
throw(ArgumentError("exactly one parser and one formatter are allowed for the inner @specset"))
484504

485-
parser = parsers[1]
486-
formatter = formatters[1]
487-
defaults = default_args === nothing ? :((;)) : :($parser($(default_args...)))
488-
489-
return quote
490-
local default_args = $defaults
491-
local ntargs_set = NamedTuple[]
492-
$blk
493-
local nsps = length(ntargs_set)
494-
local sps_set = [StatsSpec($formatter(merge(default_args, ntargs_set[i])))
495-
for i in 1:nsps]
496-
run_specset(sps_set; $(options...))
505+
parser, formatter = parsers[1], formatters[1]
506+
if default_args === nothing
507+
defaults = :(NamedTuple())
508+
else
509+
defaults = esc(:($parser($(default_args[1]...); $(default_args[2]...))))
510+
end
511+
sps = :([])
512+
for ntargs in ntargs_set
513+
push!(sps.args, :(StatsSpec($(esc(formatter))($(esc(merge))($defaults, $ntargs))...)))
497514
end
515+
return :(run_specset($sps; $options...))
498516
end

src/did.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ The order of arguments is irrelevant.
9494
- `args... kwargs...`: a list of arguments to be processed by [`parse_didargs`](@ref) and [`valid_didargs`](@ref).
9595
"""
9696
macro didspec(exprs...)
97-
args, kwargs = args_kwargs(exprs)
97+
args, kwargs = _args_kwargs(exprs)
9898
return esc(:(didspec($(args...); $(kwargs...))))
9999
end
100100

@@ -111,8 +111,6 @@ function did(args...; verbose::Bool=false, keep=nothing, keepall::Bool=false, kw
111111
return sp(verbose=verbose, keep=keep, keepall=keepall)
112112
end
113113

114-
const parse_did_options = parse_specset_options
115-
116114
"""
117115
@did [option option=val ...] "name" args... kwargs...
118116
@@ -130,16 +128,16 @@ The options available are the same as the keyword arguments available for
130128
"""
131129
macro did(args...)
132130
nargs = length(args)
131+
options = :(Dict{Symbol, Any}())
133132
if nargs > 0 && isexpr(args[1], :vect, :hcat, :vcat)
134-
options = parse_did_options(args[1].args)
133+
_parse_kwargs!(options, args[1].args)
135134
if nargs > 1
136135
didargs = args[2:end]
137136
end
138137
else
139-
options = :(Dict{Symbol, Any}())
140138
didargs = args
141139
end
142-
dargs, dkwargs = args_kwargs(didargs)
140+
dargs, dkwargs = _args_kwargs(didargs)
143141
return esc(:(StatsSpec(valid_didargs(parse_didargs($(dargs...); $(dkwargs...)))...)(; $options...)))
144142
end
145143

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ macro unpack(functionname)
9494
end
9595

9696
"""
97-
args_kwargs(exprs)
97+
_args_kwargs(exprs)
9898
9999
Partition a collection of expressions into two arrays
100100
such that all expressions in the second array has `head` being `:(=)`.
101101
This function is useful for separating out expressions
102102
for positional arguments and those for keyword arguments.
103103
"""
104-
function args_kwargs(exprs)
104+
function _args_kwargs(exprs)
105105
args = []
106106
kwargs = []
107107
for expr in exprs

test/StatsProcedures.jl

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using DiffinDiffsBase: _f, _specnames, _tracenames,
2-
_sharedby, _show_args
2+
_sharedby, _show_args, _args_kwargs, _parse_kwargs!, _spec_walker, run_specset
33

4-
testvoidstep(a::String, b::String) = nothing
5-
const TestVoidStep = StatsStep{:TestVoidStep, typeof(testvoidstep), (:a, :b), ()}
4+
testvoidstep(a::String) = nothing
5+
const TestVoidStep = StatsStep{:TestVoidStep, typeof(testvoidstep), (:a,), ()}
66

7-
testregstep(a::String, b::String) = (b=a, c=a*b,)
7+
testregstep(a::String, b::String) = (c=a*b,)
88
const TestRegStep = StatsStep{:TestRegStep, typeof(testregstep), (:a, :b), ()}
99

1010
testlaststep(a::String, c::String) = (result=a*c,)
@@ -18,11 +18,11 @@ const TestInvalidStep = StatsStep{:TestInvalidStep, typeof(testinvalidstep), (:a
1818
@test sprint(show, TestVoidStep()) == "TestVoidStep"
1919
@test sprint(show, MIME("text/plain"), TestVoidStep()) == """
2020
TestVoidStep (StatsStep that calls testvoidstep):
21-
arguments from StatsSpec: (:a, :b)
21+
arguments from StatsSpec: (:a,)
2222
arguments from trace: ()"""
2323

2424
@test _f(TestVoidStep()) == testvoidstep
25-
@test _specnames(TestVoidStep()) == (:a, :b)
25+
@test _specnames(TestVoidStep()) == (:a,)
2626
@test _tracenames(TestVoidStep()) == ()
2727

2828
@test TestVoidStep()((a="a", b="b")) == (a="a", b="b")
@@ -39,7 +39,7 @@ const TestInvalidStep = StatsStep{:TestInvalidStep, typeof(testinvalidstep), (:a
3939
@test _specnames(TestRegStep()) == (:a, :b)
4040
@test _tracenames(TestRegStep()) == ()
4141

42-
@test TestRegStep()((a="a", b="b")) == (a="a", b="a", c="ab")
42+
@test TestRegStep()((a="a", b="b")) == (a="a", b="b", c="ab")
4343
end
4444

4545
@testset "TestLastStep" begin
@@ -209,7 +209,7 @@ end
209209
@test s1(keep=:a) == (a="a", result="aab")
210210
@test s1(keep=(:a,:c)) == (a="a", c="ab", result="aab")
211211
@test_throws ArgumentError s1(keep=1)
212-
@test s1(keepall=true) == (a="a", b="a", c="ab", result="aab")
212+
@test s1(keepall=true) == (a="a", b="b", c="ab", result="aab")
213213

214214
s6 = StatsSpec("", NP, NamedTuple())
215215
@test s6() === nothing
@@ -221,3 +221,84 @@ end
221221

222222
@test _show_args(stdout, s1) === nothing
223223
end
224+
225+
function testparser(args...; kwargs...)
226+
pargs = Pair{Symbol,Any}[kwargs...]
227+
for arg in args
228+
if arg isa Type{<:AbstractStatsProcedure}
229+
push!(pargs, :p=>arg)
230+
end
231+
end
232+
return (; pargs...)
233+
end
234+
235+
testformatter(nt::NamedTuple) = (haskey(nt, :name) ? nt.name : "", nt.p, (a=nt.a, b=nt.b))
236+
237+
@testset "specset" begin
238+
@testset "run_specset" begin
239+
s1 = StatsSpec("s1", RP, (a="a",b="b"))
240+
s2 = StatsSpec("s2", RP, (a="a",b="b"))
241+
s3 = StatsSpec("s3", RP, (a="a",b="b1"))
242+
s4 = StatsSpec("s4", UP, (a="a",b="b"))
243+
s5 = StatsSpec("s5", IP, (a="a",b="b"))
244+
245+
@test run_specset([s1]) == ["aab"]
246+
@test run_specset([s1,s2], verbose=true) == ["aab", "aab"]
247+
@test run_specset([s1,s3], verbose=true) == ["aab", "aab1"]
248+
@test run_specset([s1,s4], verbose=true) == ["aab", "ab"]
249+
@test run_specset([s1,s5], verbose=true) == ["aab", "aab"]
250+
@test run_specset([s1,s4,s5], verbose=true) == ["aab", "ab", "aab"]
251+
@test_throws ArgumentError run_specset(StatsSpec[])
252+
253+
@test run_specset([s1], keep=:a) == [(a="a",result="aab")]
254+
@test run_specset([s1], keep=[:a,:b]) == [(a="a", b="b", result="aab")]
255+
@test run_specset([s1], keep=(:d,)) == [(result="aab",)]
256+
@test run_specset([s1,s4], keep=[:a, :result]) ==
257+
[(a="a", result="aab"), (a="a", c="ab")]
258+
259+
end
260+
261+
@testset "_parse_kwargs!" begin
262+
options = :(Dict{Symbol, Any}())
263+
_parse_kwargs!(options, [:(a), :(b=1)])
264+
@test eval(options) == Dict{Symbol, Any}(:a => true, :b => 1)
265+
@test_throws ArgumentError _parse_kwargs!(options, [1])
266+
end
267+
268+
@testset "@specset" begin
269+
r = @specset a="a0" begin
270+
StatsSpec(testformatter(testparser(RP; a="a1", b="b"))...)(;) end
271+
@test r == ["a1a1b"]
272+
273+
r = @specset a="a0" begin
274+
StatsSpec(testformatter(testparser(RP; b="b"))...)(;) end
275+
@test r == ["a0a0b"]
276+
277+
r = @specset a="a0" b="b0" begin
278+
StatsSpec(testformatter(testparser(RP))...)(;)
279+
StatsSpec(testformatter(testparser(RP; a="a1", b="b1"))...)(;)
280+
end
281+
@test r == ["a0a0b0", "a1a1b1"]
282+
283+
r = @specset [verbose] a="a0" b="b0" begin
284+
StatsSpec(testformatter(testparser(RP))...)(;)
285+
StatsSpec(testformatter(testparser(RP; a="a1", c="c"))...)(;)
286+
end
287+
@test r == ["a0a0b0", "a1a1b0"]
288+
289+
a = "a0"
290+
r = @specset [verbose] a=a begin
291+
StatsSpec(testformatter(testparser(RP; b="b"))...)(;) end
292+
@test r == ["a0a0b"]
293+
294+
r = []
295+
for i in 1:3
296+
a = "a"*string(i)
297+
push!(r, @specset [verbose] a=a begin
298+
StatsSpec(testformatter(testparser(RP; b="b"))...)(;)
299+
StatsSpec(testformatter(testparser(RP; b="b1"))...)(;)
300+
end)
301+
end
302+
@test r == [["a1a1b", "a1a1b1"], ["a2a2b", "a2a2b1"], ["a3a3b", "a3a3b1"]]
303+
end
304+
end

0 commit comments

Comments
 (0)