Skip to content

Commit af27b91

Browse files
committed
adapt to LearnAPI v2.0 API changes
1 parent eece8c3 commit af27b91

13 files changed

+111
-65
lines changed

src/learners/ensembling.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,8 +490,9 @@ function LearnAPI.update(
490490

491491
end
492492

493-
# needed, because model is supervised:
494-
LearnAPI.target(learner::StumpRegressor, observations) = last(observations)
493+
# training data deconstructors:
494+
LearnAPI.features(learner::StumpRegressor, data) = first(data)
495+
LearnAPI.target(learner::StumpRegressor, data) = last(data)
495496

496497
LearnAPI.predict(model::StumpRegressorFitted, ::Point, x) =
497498
_predict(model.forest, x)

src/learners/gradient_descent.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# THIS FILE IS NOT INCLUDED BUT EXISTS AS AN IMPLEMENTATION EXEMPLAR
22

3+
# TODO: This file should be updated after release of CategoricalDistributions 0.2 as
4+
# `classes` is deprecated there.
5+
36
# This file defines:
47
# - `PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.default_rng())
58

src/learners/incremental_algorithms.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,28 +67,28 @@ function LearnAPI.update_observations(model::NormalEstimatorFitted, ynew; verbos
6767
return NormalEstimatorFitted(Σy, ȳ, ss, n)
6868
end
6969

70-
LearnAPI.features(::NormalEstimator, y) = nothing
7170
LearnAPI.target(::NormalEstimator, y) = y
7271

73-
LearnAPI.predict(model::NormalEstimatorFitted, ::SingleDistribution) =
72+
LearnAPI.predict(model::NormalEstimatorFitted, ::Distribution) =
7473
Distributions.Normal(model.ȳ, sqrt(model.ss/model.n))
7574
LearnAPI.predict(model::NormalEstimatorFitted, ::Point) = model.
7675
function LearnAPI.predict(model::NormalEstimatorFitted, ::ConfidenceInterval)
77-
d = predict(model, SingleDistribution())
76+
d = predict(model, Distribution())
7877
return (quantile(d, 0.025), quantile(d, 0.975))
7978
end
8079

8180
# for fit and predict in one line:
8281
LearnAPI.predict(::NormalEstimator, k::LearnAPI.KindOfProxy, y) =
8382
predict(fit(NormalEstimator(), y), k)
84-
LearnAPI.predict(::NormalEstimator, y) = predict(NormalEstimator(), SingleDistribution(), y)
83+
LearnAPI.predict(::NormalEstimator, y) = predict(NormalEstimator(), Distribution(), y)
8584

8685
LearnAPI.extras(model::NormalEstimatorFitted) ==model.ȳ, σ=sqrt(model.ss/model.n))
8786

8887
@trait(
8988
NormalEstimator,
9089
constructor = NormalEstimator,
91-
kinds_of_proxy = (SingleDistribution(), Point(), ConfidenceInterval()),
90+
kind_of = LearnAPI.Generative(),
91+
kinds_of_proxy = (Distribution(), Point(), ConfidenceInterval()),
9292
tags = ("density estimation", "incremental algorithms"),
9393
is_pure_julia = true,
9494
human_name = "normal distribution estimator",
@@ -98,7 +98,6 @@ LearnAPI.extras(model::NormalEstimatorFitted) = (μ=model.ȳ, σ=sqrt(model.ss/
9898
:(LearnAPI.clone),
9999
:(LearnAPI.strip),
100100
:(LearnAPI.obs),
101-
:(LearnAPI.features),
102101
:(LearnAPI.target),
103102
:(LearnAPI.predict),
104103
:(LearnAPI.update_observations),

src/learners/regression.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file defines:
22

3-
# - `Ridge(; lambda=0.1)`
4-
# - `BabyRidge(; lambda=0.1)`
3+
# - `Ridge(; lambda=0.1)` (uses canned data front end)
4+
# - `BabyRidge(; lambda=0.1)` (no data front end)
55

66
using LearnAPI
77
using Tables
@@ -150,12 +150,17 @@ end
150150

151151
LearnAPI.learner(model::BabyRidgeFitted) = model.learner
152152

153+
# training data deconstructors:
154+
LearnAPI.features(learner::BabyRidge, (X, y)) = X
155+
LearnAPI.target(learner::BabyRidge, (X, y)) = y
156+
153157
LearnAPI.predict(model::BabyRidgeFitted, ::Point, Xnew) =
154158
Tables.matrix(Xnew)*model.coefficients
155159

156160
LearnAPI.strip(model::BabyRidgeFitted) =
157161
BabyRidgeFitted(model.learner, model.coefficients, nothing)
158162

163+
159164
@trait(
160165
BabyRidge,
161166
constructor = BabyRidge,

src/learners/static_algorithms.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ function LearnAPI.transform(learner::Selector, X)
5555
transform(model, X)
5656
end
5757

58-
# note the necessity of overloading `is_static` (`fit` consumes no data):
58+
# note the necessity of overloading `kind_of` (because `fit` consumes no data):
5959
@trait(
6060
Selector,
6161
constructor = Selector,
62+
kind_of = LearnAPI.Static(),
6263
tags = ("feature engineering",),
63-
is_static = true,
6464
functions = (
6565
:(LearnAPI.fit),
6666
:(LearnAPI.learner),
@@ -127,11 +127,11 @@ function LearnAPI.transform(learner::FancySelector, X)
127127
transform(model, X)
128128
end
129129

130-
# note the necessity of overloading `is_static` (`fit` consumes no data):
130+
# note the necessity of overloading `kind_of` (because `fit` consumes no data):
131131
@trait(
132132
FancySelector,
133133
constructor = FancySelector,
134-
is_static = true,
134+
kind_of = LearnAPI.Static(),
135135
tags = ("feature engineering",),
136136
functions = (
137137
:(LearnAPI.fit),

src/logging.jl

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ const QUIET = "- specify `verbosity=1` if debugging"
88

99
const CONSTRUCTOR = """
1010
11-
Testing that learner can be reconstructed from its constructors.
11+
Testing that learner can be reconstructed from its constructor.
1212
[Reference](https://juliaai.github.io/LearnAPI.jl/dev/reference/#learners).
1313
1414
"""
@@ -34,23 +34,53 @@ const FUNCTIONS = """
3434
"""
3535
const ERR_MISSINNG_OBLIGATORIES =
3636
"These obligatory functions are missing from the return value of "*
37-
"`LearnAPI.functions(learner)`; "
37+
"`LearnAPI.functions(learner)`: "
38+
39+
40+
41+
const DECONSTRUCTORS = """
42+
43+
Checking that the data deconsructors (`features`, `target` and `weights`) have only
44+
been implemented where appropriate, and looking for clues that some desirable
45+
implementations have been forgotten.
46+
47+
"""
48+
const WARN_GENERATIVE_NO_TARGET = """
49+
50+
Typically, when `LearnAPI.kind_of(learner)==LearnAPI.Generative()`, we expect
51+
`LearnAPI.target` to be implemented. If you have implemented it, check you added
52+
`:(LearnAPI.target)` to the tuple returned by `LearnAPI.functions(learner)`. If you
53+
intentionally left it unimplemented, ignore this warning.
54+
55+
"""
56+
const NO_DECONSTRUCTORS_FOR_STATIC = """
57+
58+
Since `LearnAPI.kind_of(learner)==LearnAPI.Static()`, we are checking that none of the
59+
following are in the tuple returned by `LearnAPI.functions(learner)`:
60+
`:(LearnAPI.features)`, `:(LearnAPI.target)`, `:(LearnAPI.weights)`, because there
61+
is never training data to deconstruct.
3862
39-
const FUNCTIONS3 = """
63+
"""
64+
const WARN_DESCRIMINATIVE_NO_FEATURES = """
4065
41-
Testing that `LearnAPI.functions(learner)` includes `:(LearnAPI.features).`
66+
Typically, when `LearnAPI.kind_of(learner)==LearnAPI.Descriminative()`, we expect
67+
`LearnAPI.features` to be implemented. If you have implemented it, check you added
68+
`:(LearnAPI.features)` to the tuple returned by `LearnAPI.functions(learner)`. If you
69+
intentionally left it unimplemented, ignore this warning.
4270
4371
"""
44-
const FUNCTIONS4 = """
72+
const WARN_DESCRIMINATIVE_NO_TARGET = """
4573
46-
Testing that `LearnAPI.functions(learner)` excludes `:(LearnAPI.features)`, as
47-
`LearnAPI.is_static(learner)` is `true`.
74+
Frequently, when `LearnAPI.kind_of(learner)==LearnAPI.Descriminative()`,
75+
`LearnAPI.target` is also implemented. If you have implemented it, check you added
76+
`:(LearnAPI.features)` to the tuple returned by `LearnAPI.functions(learner)`. If you
77+
intentionally left it unimplemented, ignore this warning.
4878
4979
"""
5080
const TAGS = """
5181
5282
Testing that `LearnAPI.tags(learner)` has correct form. List allowed tags with
53-
`LearnAPII.tags()`.
83+
`LearnAPI.tags()`.
5484
5585
"""
5686
const KINDS_OF_PROXY = """
@@ -62,14 +92,14 @@ const KINDS_OF_PROXY = """
6292
"""
6393
const FIT_IS_STATIC = """
6494
65-
`LearnAPI.is_static(learner)` is `true`. Therefore attempting to call
66-
`fit(learner)`.
95+
`LearnAPI.kind_of(learner)==LearnAPI.Static()`. Therefore attempting to define
96+
`model = fit(learner)`.
6797
6898
"""
6999
const FIT_IS_NOT_STATIC = """
70100
71-
Attempting to call `fit(learner, data)`. If you implemented `fit(learner)` instead,
72-
then you need to arrange `LearnAPI.is_static(learner) == true`.
101+
Attempting to define `model = fit(learner, data)`. If you implemented `fit(learner)`
102+
instead, then you need to arrange `LearnAPI.kind_of(learner)==LearnAPI.Static()`.
73103
74104
"""
75105
const LEARNER = """
@@ -86,7 +116,7 @@ const FUNCTIONS2 = """
86116
87117
"""
88118
const ERR_MISSING_FUNCTIONS =
89-
"The following overloaded functions are missing from the return value of"*
119+
"The following implemented/overloaded functions are missing from the return value of"*
90120
"`LearnAPI.functions(learner)`: "
91121

92122
const OBS = """
@@ -118,7 +148,9 @@ const PREDICT_HAS_NO_FEATURES = """
118148
119149
Attempting to call `predict(model, kind)` for each `kind` in
120150
`LearnAPI.kinds_of_proxy(learner)`. (We are not providing `predict` with a data
121-
argument because `features(obs(learner, data)) == nothing`).
151+
argument because either `LearnAPI.kind_of(learner)==LearnAPI.Generative()`, or because
152+
we presume `LearnAPI.features` is not implemented, as `:(LearnAPI.features)` is not in
153+
the return value of LearnAPI.functions(learner)`.)
122154
123155
124156
"""
@@ -256,14 +288,12 @@ const TRANSFORM_ON_SELECTIONS2 = """
256288
"""
257289
const TARGET0 = """
258290
259-
Attempting to call `LearnAPI.target(learner, data)` (fallback returns
260-
`last(data)`).
291+
Attempting to call `LearnAPI.target(learner, data)`.
261292
262293
"""
263294
const TARGET = """
264295
265-
Attempting to call `LearnAPI.target(learner, observations)` (fallback returns
266-
`last(observations)`).
296+
Attempting to call `LearnAPI.target(learner, observations)`
267297
268298
"""
269299
const TARGET_SELECTIONS = """
@@ -311,7 +341,7 @@ const UPDATE = """
311341
"""
312342
const ERR_STATIC_UPDATE = ErrorException(
313343
"`(LearnAPI.update)` is in `LearnAPI.functions(learner)` but "*
314-
"`LearnAPI.is_static(learner)` is `true`. You cannot implement `update` "*
344+
"`LearnAPI.kind_of(learner)==LearnAPI.Static()`. You cannot implement `update` "*
315345
"for static learners. "
316346
)
317347
const UPDATE_ITERATIONS = """

src/testapi.jl

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,17 @@ hyperparameter settings, are explicitly tested.
6060
6161
Each `dataset` is used as follows.
6262
63-
If `LearnAPI.is_static(learner) == false`, then:
63+
Assuming [`LearnAPI.kind_of(learner)`](@ref) returns [`LearnAPI.Descriminative()`](@ref)
64+
or [`LearnAPI.Generative()`](@ref):
6465
6566
- `dataset` is passed to `fit` and, if necessary, its `update` cousins
6667
67-
- If `X = LearnAPI.features(learner, dataset) == nothing`, then `predict` and/or
68-
`transform` are called with no data. Otherwise, they are called with `X`.
68+
- In the `Generative()` case, `predict` and/or `transform` are called without a data
69+
argument; in the `Descriminative()` case these methods are called with the data argument
70+
` X = LearnAPI.features(learner, dataset)`, assuming `:features in
71+
LearnAPI.functions(learner)`, and are otherwise not called.
6972
70-
If instead `LearnAPI.is_static(learner) == true`, then `fit` and its cousins are called
73+
If instead `LearnAPI.kind_of(learner) == Static()`, then `fit` and its cousins are called
7174
without any data, and `dataset` is passed directly to `predict` and/or `transform`.
7275
7376
"""
@@ -87,7 +90,8 @@ macro testapi(learner, data...)
8790
verbosity=$verbosity
8891
_human_name = LearnAPI.human_name(learner)
8992
_data_interface = LearnAPI.data_interface(learner)
90-
_is_static = LearnAPI.is_static(learner)
93+
_is_static = LearnAPI.kind_of(learner) == LearnAPI.Static()
94+
_is_generative = LearnAPI.kind_of(learner) == LearnAPI.Generative()
9195

9296
if isnothing(verbosity) || verbosity > 0
9397
@info "------ running @testapi - $_human_name "*$LOUD
@@ -124,13 +128,21 @@ macro testapi(learner, data...)
124128
end
125129
end
126130

127-
if !_is_static
128-
@logged_testset $FUNCTIONS3 verbosity begin
129-
Test.@test :(LearnAPI.features) in _functions
130-
end
131-
else
132-
@logged_testset $FUNCTIONS4 verbosity begin
133-
Test.@test !(:(LearnAPI.features) in _functions)
131+
_has_features = :(LearnAPI.features) in _functions
132+
_has_target = :(LearnAPI.target) in _functions
133+
_has_weights = :(LearnAPI.weights) in _functions
134+
135+
@logged_testset $DECONSTRUCTORS verbosity begin
136+
if _is_generative
137+
!_has_target && verbosity > 0 && @warn $WARN_GENERATIVE_NO_TARGET
138+
Test.@test !_has_features
139+
elseif _is_static
140+
@logged_testset $NO_DECONSTRUCTORS_FOR_STATIC verbosity begin
141+
Test.@test !_has_target && !_has_features && !_has_weights
142+
end
143+
else
144+
!_has_features && verbosity > 0 && @warn $WARN_DESCRIMINATIVE_NO_FEATURES
145+
!_has_features && verbosity > 0 && @warn $WARN_DESCRIMINATIVE_NO_TARGET
134146
end
135147
end
136148

@@ -199,13 +211,15 @@ macro testapi(learner, data...)
199211

200212
X = if _is_static
201213
data
202-
else
214+
elseif _has_features
203215
@logged_testset $FEATURES0 verbosity begin
204216
LearnAPI.features(learner, data)
205217
end
206218
@logged_testset $FEATURES verbosity begin
207219
LearnAPI.features(learner, observations)
208220
end
221+
else
222+
nothing
209223
end
210224

211225
if !(isnothing(X))
@@ -467,25 +481,18 @@ macro testapi(learner, data...)
467481

468482
# weights
469483

470-
_w = @logged_testset $WEIGHTS verbosity begin
471-
LearnAPI.weights(learner, observations)
472-
end
473-
474-
if !(isnothing(_w))
475-
@logged_testset $WEIGHTS_IN_FUNCTIONS verbosity begin
476-
Test.@test :(LearnAPI.weights) in _functions
484+
if :(LearnAPI.weights) in _functions
485+
_w = @logged_testset $WEIGHTS verbosity begin
486+
LearnAPI.weights(learner, observations)
477487
end
488+
478489
w = @logged_testset $WEIGHTS_SELECTIONS verbosity begin
479490
LearnTestAPI.learner_get(
480491
learner,
481492
data,
482493
data->LearnAPI.weights(learner, data),
483494
)
484495
end
485-
else
486-
@logged_testset $WEIGHTS_NOT_IN_FUNCTIONS verbosity begin
487-
Test.@test !(:(LearnAPI.weights) in _functions)
488-
end
489496
end
490497

491498
# update
@@ -612,7 +619,7 @@ macro testapi(learner, data...)
612619

613620
# traits
614621
# `constructor`, `functions`, `kinds_of_proxy`, `tags`, `nonlearners`,
615-
# `iteration_parameter`, `data_interface`, `is_static` tested already above
622+
# `iteration_parameter`, `data_interface`, `kind_of` tested already above
616623

617624
@logged_testset $PKG_NAME verbosity begin
618625
pkg_name = LearnAPI.pkg_name(learner)
@@ -686,3 +693,4 @@ macro testapi(learner, data...)
686693
nothing
687694
end # quote
688695
end # macro
696+

test/learners/classification.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ f = @formula(t ~ c + a)
4545
# # TESTS
4646

4747
learner = LearnTestAPI.ConstantClassifier()
48-
@testapi learner (X1, y)
48+
@testapi learner (X1, y) verbosity=0
4949
@testapi learner (X2, y) (X3, y) (X4, y) (T1, :t) (T2, :t) (T3, f) (T4, f) verbosity=0
5050

5151
@testset "extra tests for constant classifier" begin

test/learners/dimension_reduction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ U, Vt = r.U, r.Vt
1212
X = U*diagm([1, 2, 3, 0.01, 0.01])*Vt
1313

1414
learner = LearnTestAPI.TruncatedSVD(codim=2)
15-
@testapi learner X verbosity=1
15+
@testapi learner X verbosity=0
1616

1717
@testset "extra test for truncated SVD" begin
1818
model = @test_logs(

0 commit comments

Comments
 (0)