
Commit 82a9e68

add trait tests
1 parent 7d45e08 commit 82a9e68

7 files changed: +128 -96 lines changed


docs/src/common_implementation_patterns.md

Lines changed: 18 additions & 18 deletions

@@ -16,30 +16,30 @@
 This guide is intended to be consulted after reading [Anatomy of an Implementation](@ref),
 which introduces the main interface objects and terminology.

-Although an implementation is defined purely by the methods and traits it implements, most
+Although an implementation is defined purely by the methods and traits it implements, many
 implementations fall into one (or more) of the following informally understood patterns or
 "tasks":

 - [Regression](@ref): Supervised learners for continuous targets

-- [Classification](@ref): Supervised learners for categorical targets
+- Classification: Supervised learners for categorical targets

-- [Clusterering](@ref): Algorithms that group data into clusters for classification and
+- Clusterering: Algorithms that group data into clusters for classification and
   possibly dimension reduction. May be true learners (generalize to new data) or static.

-- [Gradient Descent](@ref): Including neural networks.
+- Gradient Descent: Including neural networks.

 - [Iterative Algorithms](@ref)

-- [Incremental Algorithms](@ref)
+- Incremental Algorithms

 - [Feature Engineering](@ref): Algorithms for selecting or combining features

-- [Dimension Reduction](@ref): Transformers that learn to reduce feature space dimension
+- Dimension Reduction: Transformers that learn to reduce feature space dimension

-- [Missing Value Imputation](@ref)
+- Missing Value Imputation

-- [Transformers](@ref): Other transformers, such as standardizers, and categorical
+- Transformers: Other transformers, such as standardizers, and categorical
   encoders.

 - [Static Algorithms](@ref): Algorithms that do not learn, in the sense they must be

@@ -48,26 +48,26 @@ implementations fall into one (or more) of the following informally understood p

 - [Ensemble Algorithms](@ref): Algorithms that blend predictions of multiple algorithms

-- [Time Series Forecasting](@ref)
+- Time Series Forecasting

-- [Time Series Classification](@ref)
+- Time Series Classification

-- [Survival Analysis](@ref)
+- Survival Analysis

-- [Density Estimation](@ref): Algorithms that learn a probability distribution
+- Density Estimation: Algorithms that learn a probability distribution

-- [Bayesian Algorithms](@ref)
+- Bayesian Algorithms

-- [Outlier Detection](@ref): Supervised, unsupervised, or semi-supervised learners for
+- Outlier Detection: Supervised, unsupervised, or semi-supervised learners for
   anomaly detection.

-- [Text Analysis](@ref)
+- Text Analysis

-- [Audio Analysis](@ref)
+- Audio Analysis

-- [Natural Language Processing](@ref)
+- Natural Language Processing

-- [Image Processing](@ref)
+- Image Processing

 - [Meta-algorithms](@ref)

Lines changed: 5 additions & 0 deletions

@@ -1 +1,6 @@
 # Meta-algorithms
+
+Many meta-algorithms are wrappers. An example is [this bagged ensemble
+algorithm](https://github.com/JuliaAI/LearnAPI.jl/blob/dev/test/integration/iterative_algorithms.jl)
+from tests.
+
docs/src/reference.md

Lines changed: 3 additions & 1 deletion

@@ -141,7 +141,9 @@ for each.
 [`LearnAPI.algorithm`](@ref algorithm_minimize), [`LearnAPI.constructor`](@ref) and
 [`LearnAPI.functions`](@ref).

-Most algorithms will also implement [`predict`](@ref) and/or [`transform`](@ref).
+Most algorithms will also implement [`predict`](@ref) and/or [`transform`](@ref). For a
+bare minimum implementation, see the implementation of `SmallAlgorithm`
+[here](https://github.com/JuliaAI/LearnAPI.jl/blob/dev/test/traits.jl).

 ### List of methods

docs/src/traits.md

Lines changed: 7 additions & 7 deletions

@@ -26,8 +26,8 @@ In the examples column of the table below, `Continuous` is a name owned the pack
 | [`LearnAPI.load_path`](@ref)`(algorithm)` | string locating name returned by `LearnAPI.constructor(algorithm)`, beginning with a package name | `"unknown"` | `FastTrees.LearnAPI.DecisionTreeClassifier` |
 | [`LearnAPI.is_composite`](@ref)`(algorithm)` | `true` if one or more properties of `algorithm` may be an algorithm | `false` | `true` |
 | [`LearnAPI.human_name`](@ref)`(algorithm)` | human name for the algorithm; should be a noun | type name with spaces | "elastic net regressor" |
-| [`LearnAPI.data_interface`](@ref)`(algorithm)` | Interface implemented by objects returned by [`obs`](@ref) | `Base.HasLength()` (supports `MLUtils.getobs/numobs`) | `Base.SizeUnknown()` (supports `iterate`) |
 | [`LearnAPI.iteration_parameter`](@ref)`(algorithm)` | symbolic name of an iteration parameter | `nothing` | :epochs |
+| [`LearnAPI.data_interface`](@ref)`(algorithm)` | Interface implemented by objects returned by [`obs`](@ref) | `Base.HasLength()` (supports `MLUtils.getobs/numobs`) | `Base.SizeUnknown()` (supports `iterate`) |
 | [`LearnAPI.fit_observation_scitype`](@ref)`(algorithm)` | upper bound on `scitype(observation)` for `observation` in `data` ensuring `fit(algorithm, data)` works | `Union{}` | `Tuple{AbstractVector{Continuous}, Continuous}` |
 | [`LearnAPI.target_observation_scitype`](@ref)`(algorithm)` | upper bound on the scitype of each observation of the targget | `Any` | `Continuous` |
 | [`LearnAPI.predict_or_transform_mutates`](@ref)`(algorithm)` | `true` if `predict` or `transform` mutates first argument | `false` | `true` |

@@ -36,12 +36,12 @@ In the examples column of the table below, `Continuous` is a name owned the pack

 The following are provided for convenience but should not be overloaded by new algorithms:

-| trait | return value | example |
-|:-----------------------------------|:---------------------------------------------------------------------|:--------|
-| `LearnAPI.name(algorithm)` | algorithm type name as string | "PCA" |
-| `LearnAPI.is_algorithm(algorithm)` | `true` if `algorithm` is LearnAPI.jl-compliant | `true` |
-| `LearnAPI.target(algorithm)` | `true` if [`LearnAPI.target(algorithm, data)`](@ref) is implemented | `false` |
-| `LearnAPI.weights(algorithm)` | `true` if [`LearnAPI.weights(algorithm, data)`](@ref) is implemented | `false` |
+| trait | return value | example |
+|:-----------------------------------|:-------------------------------------------------------------------------|:--------|
+| `LearnAPI.name(algorithm)` | algorithm type name as string | "PCA" |
+| `LearnAPI.is_algorithm(algorithm)` | `true` if `algorithm` is LearnAPI.jl-compliant | `true` |
+| `LearnAPI.target(algorithm)` | `true` if `fit` sees a target variable; see [`LearnAPI.target`](@ref) | `false` |
+| `LearnAPI.weights(algorithm)` | `true` if `fit` supports per-observation; see [`LearnAPI.weights`](@ref) | `false` |

 ## Implementation guide

src/traits.jl

Lines changed: 1 addition & 24 deletions

@@ -23,29 +23,6 @@ const DOC_EXPLAIN_EACHOBS =

     """

-const TRAITS = [
-    :constructor,
-    :functions,
-    :kinds_of_proxy,
-    :tags,
-    :is_pure_julia,
-    :pkg_name,
-    :pkg_license,
-    :doc_url,
-    :load_path,
-    :is_composite,
-    :human_name,
-    :iteration_parameter,
-    :data_interface,
-    :predict_or_transform_mutates,
-    :fit_observation_scitype,
-    :target_observation_scitype,
-    :name,
-    :is_algorithm,
-    :target,
-]
-
-
 # # OVERLOADABLE TRAITS

 """

@@ -426,7 +403,7 @@ variable. Specifically:
 variables) then "target" means anything returned by `LearnAPI.target(algorithm, data)`,
 where `data` is an admissible argument in the call `fit(algorithm, data)`.

-- `S` will always be an upper bound on the scitype of observations that could be
+- `S` will always be an upper bound on the scitype of (point) observations that could be
   conceivably extracted from the output of [`predict`](@ref).

 To illustate the second case, suppose we have
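As a hedged illustration of the "(point)" qualification, assuming a probabilistic regressor with `S = Continuous`, and assuming ScientificTypes.jl is in scope to provide `scitype`:

using LearnAPI
using ScientificTypes  # assumption: provides `scitype` and `Continuous`

# Target observations are continuous, so S = Continuous bounds them:
scitype(4.2)  # Continuous

# But the output of `predict(model, Distribution(), Xnew)` consists of
# distributions, not `Continuous` values; `S` only bounds the *point*
# observations one could conceivably extract from that output, e.g. the
# mean of each predicted distribution.
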

test/integration/iterative_algorithms.jl

Lines changed: 48 additions & 45 deletions

@@ -7,56 +7,57 @@ using Random
 using Statistics
 using StableRNGs

-# # ENSEMBLE OF RIDGE REGRESSORS
-
-# We implement a toy algorithm that creates an bagged ensemble of ridge regressors (as
-# defined already in test/integration/regressors.jl), i.e, where each atomic model is
-# trained on a random sample of the training observations (same number, but sampled with
-# replacement). In particular this algorithm has an iteration parameter `n`, and we
-# implement `update` for warm restarts when `n` increases.
-
-# no docstring here - that goes with the constructor
-struct RidgeEnsemble
-    lambda::Float64
-    rng # leaving abstract for simplicity
+# # ENSEMBLE OF REGRESSORS (A MODEL WRAPPER)
+
+# We implement a toy algorithm that creates an bagged ensemble of regressors, i.e, where
+# each atomic model is trained on a random sample of the training observations (same
+# number, but sampled with replacement). In particular this algorithm has an iteration
+# parameter `n`, and we implement `update` for warm restarts when `n` increases.
+
+# no docstring here - that goes with the constructor; some fields left abstract for
+# simplicity
+#
+struct Ensemble
+    atom # the base regressor being bagged
+    rng
     n::Int
 end

+# Since the `atom` hyperparameter is another algorithm, it doesn't need a default in the
+# kwarg constructor, but we do need to overload the `LearnAPI.is_composite` trait (done
+# later).
+
 """
-    RidgeEnsemble(; lambda=0.1, rng=Random.default_rng(), n=10)
+    Ensemble(atom; rng=Random.default_rng(), n=10)

-Instantiate a RidgeEnsemble algorithm, bla, bla, bla...
+Instantiate a bagged ensemble of `n` regressors, with base regressor `atom`, etc

 """
-RidgeEnsemble(; lambda=0.1, rng=Random.default_rng(), n=10) =
-    RidgeEnsemble(lambda, rng, n) # LearnAPI.constructor defined later
+Ensemble(atom; rng=Random.default_rng(), n=10) =
+    Ensemble(atom, rng, n) # `LearnAPI.constructor` defined later

-struct RidgeEnsembleFitted
-    algorithm::RidgeEnsemble
+struct EnsembleFitted
+    algorithm::Ensemble
     atom::Ridge
     rng # mutated copy of `algorithm.rng`
     models # leaving type abstract for simplicity
 end

-LearnAPI.algorithm(model::RidgeEnsembleFitted) = model.algorithm
+LearnAPI.algorithm(model::EnsembleFitted) = model.algorithm

-# We add the same data interface we provided for `Ridge` in regression.jl. This is an
-# optional step on which the later code does not depend.
-LearnAPI.obs(algorithm::RidgeEnsemble, data) = LearnAPI.obs(Ridge(), data)
-LearnAPI.obs(model::RidgeEnsembleFitted, data) = LearnAPI.obs(first(model.models), data)
-LearnAPI.target(algorithm::RidgeEnsemble, data) = LearnAPI.target(Ridge(), data)
-LearnAPI.features(algorithm::Ridge, data) = LearnAPI.features(Ridge(), data)
+# We add the same data interface that the atomic regressor uses:
+LearnAPI.obs(algorithm::Ensemble, data) = LearnAPI.obs(algorithm.atom, data)
+LearnAPI.obs(model::EnsembleFitted, data) = LearnAPI.obs(first(model.models), data)
+LearnAPI.target(algorithm::Ensemble, data) = LearnAPI.target(algorithm.atom, data)
+LearnAPI.features(algorithm::Ridge, data) = LearnAPI.features(algorithm.atom, data)

-function LearnAPI.fit(algorithm::RidgeEnsemble, data; verbosity=1)
+function LearnAPI.fit(algorithm::Ensemble, data; verbosity=1)

     # unpack hyperparameters:
-    lambda = algorithm.lambda
-    rng = deepcopy(algorithm.rng) # to prevent mutation of `algorithm`
+    atom = algorithm.atom
+    rng = deepcopy(algorithm.rng) # to prevent mutation of `algorithm`!
     n = algorithm.n

-    # instantiate atomic algorithm:
-    atom = Ridge(lambda)
-
     # ensure data can be subsampled using MLUtils.jl, and that we're feeding the atomic
     # `fit` data in an efficient (pre-processed) form:

@@ -80,15 +81,16 @@ function LearnAPI.fit(algorithm::RidgeEnsemble, data; verbosity=1)
     # make some noise, if allowed:
     verbosity > 0 && @info "Trained $n ridge regression models. "

-    return RidgeEnsembleFitted(algorithm, atom, rng, models)
+    return EnsembleFitted(algorithm, atom, rng, models)

 end

-# If `n` is increased, this `update` adds new regressors to the ensemble, including any
-# new # hyperparameter updates (e.g, `lambda`) when computing the new
-# regressors. Otherwise, update is equivalent to retraining from scratch, with the
-# provided hyperparameter updates.
-function LearnAPI.update(model::RidgeEnsembleFitted, data; verbosity=1, replacements...)
+# Consistent with the documented `update` contract, we implement this behaviour: If `n` is
+# increased, `update` adds new regressors to the ensemble, including any new
+# hyperparameter updates (e.g, new `atom`) when computing the new atomic
+# models. Otherwise, update is equivalent to retraining from scratch, with the provided
+# hyperparameter updates.
+function LearnAPI.update(model::EnsembleFitted, data; verbosity=1, replacements...)
     :n in keys(replacements) || return fit(model, data)

     algorithm_old = LearnAPI.algorithm(model)

@@ -97,7 +99,7 @@ function LearnAPI.update(model::RidgeEnsembleFitted, data; verbosity=1, replacem
     Δn = n - algorithm_old.n
     n < 0 && return fit(model, algorithm)

-    atom = Ridge(; lambda=algorithm.lambda)
+    atom = algorithm.atom
     observations = obs(atom, data)
     N = MLUtils.numobs(observations)

@@ -116,15 +118,15 @@ function LearnAPI.update(model::RidgeEnsembleFitted, data; verbosity=1, replacem
     # make some noise, if allowed:
     verbosity > 0 && @info "Trained $Δn additional ridge regression models. "

-    return RidgeEnsembleFitted(algorithm, atom, rng, models)
+    return EnsembleFitted(algorithm, atom, rng, models)
 end

-LearnAPI.predict(model::RidgeEnsembleFitted, ::Point, data) =
+LearnAPI.predict(model::EnsembleFitted, ::Point, data) =
     mean(model.models) do atomic_model
         predict(atomic_model, Point(), data)
     end

-LearnAPI.minimize(model::RidgeEnsembleFitted) = RidgeEnsembleFitted(
+LearnAPI.minimize(model::EnsembleFitted) = EnsembleFitted(
     model.algorithm,
     model.atom,
     model.rng,

@@ -133,9 +135,10 @@ LearnAPI.minimize(model::RidgeEnsembleFitted) = RidgeEnsembleFitted(

 # note the inclusion of `iteration_parameter`:
 @trait(
-    RidgeEnsemble,
-    constructor = RidgeEnsemble,
+    Ensemble,
+    constructor = Ensemble,
     iteration_parameter = :n,
+    is_composite = true,
     kinds_of_proxy = (Point(),),
     tags = ("regression", "ensemble algorithms", "iterative models"),
     functions = (

@@ -165,7 +168,8 @@ Xtest = Tables.subset(X, test)

 @testset "test an implementation of bagged ensemble of ridge regressors" begin
     rng = StableRNG(123)
-    algorithm = RidgeEnsemble(lambda=0.5, n=4; rng)
+    atom = Ridge()
+    algorithm = Ensemble(atom; n=4, rng)
     @test LearnAPI.clone(algorithm) == algorithm
     @test :(LearnAPI.obs) in LearnAPI.functions(algorithm)
     @test LearnAPI.target(algorithm, data) == y

@@ -190,7 +194,6 @@ Xtest = Tables.subset(X, test)
     # compare with cold restart:
     model = fit(LearnAPI.clone(algorithm; n=7), Xtrain, y[train]; verbosity=0);
     @test ŷ7 ≈ predict(model, Xtest)
-
 end

 true
test/traits.jl

Lines changed: 46 additions & 1 deletion

@@ -1,6 +1,51 @@
-module FruitSalad
+using Test
 using LearnAPI

+# A MINIMUM IMPLEMENTATION OF AN ALGORITHM
+
+# does nothing useful
+struct SmallAlgorithm end
+LearnAPI.fit(algorithm::SmallAlgorithm, data; verbosity=1) = algorithm
+LearnAPI.algorithm(algorithm::SmallAlgorithm) = algorithm
+@trait(
+    SmallAlgorithm,
+    constructor = SmallAlgorithm,
+    functions = (
+        :(LearnAPI.fit),
+        :(LearnAPI.algorithm),
+    ),
+)
+######## END OF IMPLEMENTATION ##################
+
+# ZERO ARGUMENT METHODS
+
+@test :(LearnAPI.fit) in LearnAPI.functions()
+@test Point in LearnAPI.kinds_of_proxy()
+@test "regression" in LearnAPI.tags()
+
+# OVERLOADABLE TRAITS
+
+small = SmallAlgorithm()
+@test !LearnAPI.is_pure_julia(small)
+@test LearnAPI.pkg_name(small) == "unknown"
+@test LearnAPI.pkg_license(small) == "unknown"
+@test LearnAPI.load_path(small) == "unknown"
+@test !LearnAPI.is_composite(small)
+@test LearnAPI.human_name(small) == "small algorithm"
+@test isnothing(LearnAPI.iteration_parameter(small))
+@test LearnAPI.data_interface(small) == LearnAPI.RandomAccess()
+@test !(6 isa LearnAPI.fit_observation_scitype(small))
+@test 6 isa LearnAPI.target_observation_scitype(small)
+
+# DERIVED TRAITS
+
+@test LearnAPI.is_algorithm(small)
+@test !LearnAPI.target(small)
+@test !LearnAPI.weights(small)
+
+module FruitSalad
+import LearnAPI
+
 struct RedApple{T}
     x::T
 end
