Skip to content

Commit 84ff503

Browse files
authored
Merge pull request #766 from JuliaAI/rh/std
Add std to show for `PerformanceEvaluation`
2 parents 1f33881 + 5dbd187 commit 84ff503

File tree

2 files changed

+44
-10
lines changed

2 files changed

+44
-10
lines changed

src/resampling.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -462,12 +462,19 @@ outlier detection model.
462462
When `evaluate`/`evaluate!` is called, a number of train/test pairs
463463
("folds") of row indices are generated, according to the options
464464
provided, which are discussed in the [`evaluate!`](@ref)
465-
doc-string. Rows correspond to observations. The train/test pairs
466-
generated are recorded in the `train_test_rows` field of the
465+
doc-string. Rows correspond to observations. The generated train/test
466+
pairs are recorded in the `train_test_rows` field of the
467467
`PerformanceEvaluation` struct, and the corresponding estimates,
468468
aggregated over all train/test pairs, are recorded in `measurement`, a
469469
vector with one entry for each measure (metric) recorded in `measure`.
470470
471+
When displayed, a `PerformanceEvaluation` object includes a value under
472+
the heading `1.96*SE`, derived from the standard error of the `per_fold`
473+
entries. This value is suitable for constructing a formal 95%
474+
confidence interval for the given `measurement`. Such intervals should
475+
be interpreted with caution. See, for example, Bates et al.
476+
[(2021)](https://arxiv.org/abs/2104.00673).
477+
471478
### Fields
472479
473480
These fields are part of the public API of the `PerformanceEvaluation`
@@ -503,8 +510,9 @@ struct.
503510
machine `mach` training in resampling - one machine per train/test
504511
pair.
505512
506-
- `train_test_rows`: a vector of tuples, each of the form `(train, test)`, where `train` and `test`
507-
are vectors of row (observation) indices for training and evaluation respectively.
513+
- `train_test_rows`: a vector of tuples, each of the form `(train, test)`,
514+
where `train` and `test` are vectors of row (observation) indices for
515+
training and evaluation respectively.
508516
"""
509517
struct PerformanceEvaluation{M,
510518
Measurement,
@@ -532,18 +540,35 @@ _short(v::Vector{<:Real}) = MLJBase.short_string(v)
532540
_short(v::Vector) = string("[", join(_short.(v), ", "), "]")
533541
_short(::Missing) = missing
534542

535-
function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
536-
_measure = map(e.measure) do m
537-
repr(MIME("text/plain"), m)
543+
"""
    _standard_errors(e::PerformanceEvaluation)

Return the values displayed under the `1.96*SE` heading when `e` is shown.

For each element `per_fold` of `e.per_fold`, the corresponding entry is
`1.96 * std(per_fold) / sqrt(nfolds - 1)`, where `nfolds` is the number of entries in
the first element of `e.per_fold`.

If there is only one fold, no spread can be estimated and `[nothing]` is returned
instead.
"""
function _standard_errors(e::PerformanceEvaluation)
    factor = 1.96 # For the 95% confidence interval.
    nfolds = length(first(e.per_fold))
    # A single fold gives no estimate of spread; signal this with `nothing`.
    nfolds == 1 && return [nothing]
    return map(e.per_fold) do per_fold
        factor * std(per_fold) / sqrt(nfolds - 1)
    end
end
553+
554+
function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
555+
_measure = [repr(MIME("text/plain"), m) for m in e.measure]
539556
_measurement = round3.(e.measurement)
540557
_per_fold = [round3.(v) for v in e.per_fold]
558+
_sterr = round3.(_standard_errors(e))
559+
560+
# Only show the standard error if the number of folds is higher than 1.
561+
show_sterr = any(!isnothing, _sterr)
562+
data = show_sterr ?
563+
hcat(_measure, e.operation, _measurement, _sterr, _per_fold) :
564+
hcat(_measure, e.operation, _measurement, _per_fold)
565+
header = show_sterr ?
566+
["measure", "operation", "measurement", "1.96*SE", "per_fold"] :
567+
["measure", "operation", "measurement", "per_fold"]
541568

542-
data = hcat(_measure, _measurement, e.operation, _per_fold)
543-
header = ["measure", "measurement", "operation", "per_fold"]
544569
println(io, "PerformanceEvaluation object "*
545570
"with these fields:")
546-
println(io, " measure, measurement, operation, per_fold,\n"*
571+
println(io, " measure, operation, measurement, per_fold,\n"*
547572
" per_observation, fitted_params_per_fold,\n"*
548573
" report_per_fold, train_test_rows")
549574
println(io, "Extract:")

test/resampling.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,12 +775,21 @@ end
775775
@test T <: PerformanceEvaluation
776776

777777
show_text = sprint(show, MIME"text/plain"(), evaluations)
778+
cols = ["measure", "operation", "measurement", "1.96*SE", "per_fold"]
779+
@test all(contains.(show_text, cols))
780+
print(show_text)
778781
docstring_text = string(@doc(PerformanceEvaluation))
779782
for fieldname in fieldnames(PerformanceEvaluation)
780783
@test contains(show_text, string(fieldname))
781784
# string(text::Markdown.MD) converts `-` list items to `*`.
782785
@test contains(docstring_text, " * `$fieldname`")
783786
end
787+
788+
measures = [LogLoss(), Accuracy()]
789+
evaluations = evaluate(clf, X, y; measures, resampling=Holdout())
790+
show_text = sprint(show, MIME"text/plain"(), evaluations)
791+
print(show_text)
792+
@test !contains(show_text, "std")
784793
end
785794

786795
#end

0 commit comments

Comments
 (0)