Commit 105d7ff

more doc tweaks
1 parent 1449814 commit 105d7ff

9 files changed, +83 -78 lines changed

docs/make.jl

Lines changed: 1 addition & 1 deletion
````diff
@@ -18,8 +18,8 @@ makedocs(
 "fit/update" => "fit_update.md",
 "predict/transform" => "predict_transform.md",
 "Kinds of Target Proxy" => "kinds_of_target_proxy.md",
-"target/weights/features" => "target_weights_features.md",
 "obs" => "obs.md",
+"target/weights/features" => "target_weights_features.md",
 "Accessor Functions" => "accessor_functions.md",
 "Learner Traits" => "traits.md",
 ],
````

docs/src/anatomy_of_an_implementation.md

Lines changed: 5 additions & 11 deletions
````diff
@@ -462,21 +462,14 @@ LearnAPI.predict(model::RidgeFitted, ::Point, Xnew) =

 ### `target` and `features` methods

-We provide an additional overloading of [`LearnAPI.target`](@ref) to handle the additional
-supported data argument of `fit`:
+In the general case, we only need to implement [`LearnAPI.target`](@ref) and
+[`LearnAPI.features`](@ref) to handle all possible output of `obs(learner, data)`, and now
+the fallback for `LearnAPI.features` mentioned before is inadequate.

 ```@example anatomy2
 LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
-```
-
-Similarly, we must overload [`LearnAPI.features`](@ref), which extracts features from
-training data (objects that can be passed to `predict`) like this
-
-```@example anatomy2
 LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
 ```
-as the fallback mentioned above is no longer adequate.
-

 ### Important notes:

````
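To see why the hunk above calls the `LearnAPI.features` fallback inadequate once `obs` returns a custom type, here is a minimal stand-alone sketch. It does not load LearnAPI.jl: `features`, `target` and `_first` are local stand-ins modelled on the fallbacks shown elsewhere in this commit, and the field layout of `RidgeFitObs` and the two `obs` methods are assumptions made purely for illustration.

```julia
# Local stand-ins for the LearnAPI.jl fallbacks (copied in spirit from this commit).
features(learner, observations) = _first(observations)
_first(observations) = observations
_first(observations::Tuple) = first(observations)
target(::Any, observations) = nothing

struct Ridge end                 # hypothetical learner with no hyperparameters

struct RidgeFitObs{T}            # hypothetical output type of `obs(learner, data)`
    A::Matrix{T}                 # features
    y::Vector{T}                 # target
end

# Hypothetical `obs` methods: wrap a tuple, pass already-wrapped data through.
obs(::Ridge, data::Tuple) = RidgeFitObs(data[1], data[2])
obs(::Ridge, observations::RidgeFitObs) = observations

# The two overloads from the hunk above.
target(::Ridge, observations::RidgeFitObs) = observations.y
features(::Ridge, observations::RidgeFitObs) = observations.A

learner = Ridge()
X = [1.0 2.0; 3.0 4.0]
y = [1.0, 2.0]
observations = obs(learner, (X, y))

# What the generic fallback would have produced: the whole wrapper, target included.
fallback_result = _first(observations)

# What the overloads produce: just the parts.
Xtrain = features(learner, observations)
ytrain = target(learner, observations)
```

Without the `features` overload, the fallback would hand back the entire `RidgeFitObs` object, which is not something one would expect `predict` to accept as new features.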

````diff
@@ -501,7 +494,8 @@ interfaces](@ref data_interfaces) for details.

 ### Addition of signatures for user convenience

-As above, we add a signature which plays no role vis-à-vis LearnAPI.jl.
+As above, we add a signature for convenience, which the LearnAPI.jl specification
+neither requires nor forbids:

 ```@example anatomy2
 LearnAPI.fit(learner::Ridge, X, y; kwargs...) = fit(learner, (X, y); kwargs...)
````

docs/src/common_implementation_patterns.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -3,7 +3,7 @@
 !!! important

 This section is only an implementation guide. The definitive specification of the
-Learn API is given in [Reference](@ref reference).
+LearnAPI is given in [Reference](@ref reference).

 This guide is intended to be consulted after reading [Anatomy of an Implementation](@ref),
 which introduces the main interface objects and terminology.
````

docs/src/index.md

Lines changed: 13 additions & 3 deletions
````diff
@@ -1,5 +1,15 @@
 ```@raw html
 <script async defer src="https://buttons.github.io/buttons.js"></script>
+
+<div style="font-size:1.4em;font-weight:bold;">
+<a href="anatomy_of_an_implementation.html"
+style="color: #389826;">Tutorial</a> &nbsp;|&nbsp;
+<a href="reference.html"
+style="color: #9558B2;">Reference</a> &nbsp;|&nbsp;
+<a href="common_implementation_patterns.html"
+style="color: #9558B2;">Patterns</a>
+</div>
+
 <span style="color: #9558B2;font-size:4.5em;">
 LearnAPI.jl</span>
 <br>
@@ -86,11 +96,11 @@ opts out. Moreover, the `fit` and `predict` methods will also be able to consume
 alternative data representations, for performance benefits in some situations.

 The fallback data interface is the [MLUtils.jl](https://github.com/JuliaML/MLUtils.jl)
-`getobs/numobs` interface (here tagged as [`LearnAPI.RandomAccess()`](@ref)) and if the
+`getobs/numobs` interface, here tagged as [`LearnAPI.RandomAccess()`](@ref), and if the
 input consumed by the algorithm already implements that interface (tables, arrays, etc.)
 then overloading `obs` is completely optional. Plain iteration interfaces, with or without
-knowledge of the number of observations, can also be specified (to support, e.g., data
-loaders reading images from disk).
+knowledge of the number of observations, can also be specified, to support, e.g., data
+loaders reading images from disk.

 ## Learning more

````

docs/src/reference.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -170,15 +170,15 @@ minimal (but useless) implementation, see the implementation of `SmallLearner`
 - [`inverse_transform`](@ref operations): for inverting the output of
 `transform` ("inverting" broadly understood)

-- [`LearnAPI.target`](@ref input), [`LearnAPI.weights`](@ref input),
-[`LearnAPI.features`](@ref): for extracting relevant parts of training data, where
-defined.
-
 - [`obs`](@ref data_interface): method for exposing to the user
 learner-specific representations of data, which are additionally guaranteed to
 implement the observation access API specified by
 [`LearnAPI.data_interface(learner)`](@ref).

+- [`LearnAPI.target`](@ref input), [`LearnAPI.weights`](@ref input),
+[`LearnAPI.features`](@ref): for extracting relevant parts of training data, where
+defined.
+
 - [Accessor functions](@ref accessor_functions): these include functions like
 `LearnAPI.feature_importances` and `LearnAPI.training_losses`, for extracting, from
 training outcomes, information common to many learners. This includes
````

docs/src/target_weights_features.md

Lines changed: 8 additions & 5 deletions
````diff
@@ -1,11 +1,13 @@
 # [`target`, `weights`, and `features`](@id input)

-Methods for extracting parts of training data:
+Methods for extracting parts of training observations. Here "observations" means the
+output of [`obs(learner, data)`](@ref); if `obs` is not overloaded for `learner`, then
+"observations" is any `data` supported in calls of the form [`fit(learner, data)`](@ref)

 ```julia
-LearnAPI.target(learner, data) -> <target variable>
-LearnAPI.weights(learner, data) -> <per-observation weights>
-LearnAPI.features(learner, data) -> <training "features", suitable input for `predict` or `transform`>
+LearnAPI.target(learner, observations) -> <target variable>
+LearnAPI.weights(learner, observations) -> <per-observation weights>
+LearnAPI.features(learner, observations) -> <training "features", suitable input for `predict` or `transform`>
 ```

 Here `data` is something supported in a call of the form `fit(learner, data)`.
````
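The case where `obs` is not overloaded, so that "observations" are simply the `data` accepted by `fit`, might look like the following sketch. It runs without LearnAPI.jl: the three generic functions are local stand-ins that mimic the fallbacks in this commit, and `WeightedClassifier` with its `(X, y)` / `(X, y, w)` tuple convention is a hypothetical learner invented for the example.

```julia
# Local stand-ins mimicking the fallbacks in src/target_weights_features.jl.
target(::Any, observations) = nothing
weights(::Any, observations) = nothing
features(learner, observations) = _first(observations)
_first(observations) = observations
_first(observations::Tuple) = first(observations)

struct WeightedClassifier end    # hypothetical learner; `obs` is NOT overloaded

# Assumed convention: `fit` accepts `(X, y)` or `(X, y, w)`.
target(::WeightedClassifier, observations::Tuple) = observations[2]
weights(::WeightedClassifier, observations::Tuple{Any,Any,Any}) = observations[3]

learner = WeightedClassifier()
X = rand(4, 2)
y = ["a", "b", "a", "b"]
w = [0.1, 0.2, 0.3, 0.4]

Xtrain = features(learner, (X, y, w))   # fallback: first element of the tuple
ytrain = target(learner, (X, y))        # overload: second element
no_w   = weights(learner, (X, y))       # fallback: `nothing`, read as uniform weights
wtrain = weights(learner, (X, y, w))    # overload: third element
```

Here only `target` and `weights` need methods; the tuple-aware `features` fallback already returns `X`.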
````diff
@@ -19,7 +21,8 @@ Supposing `learner` is a supervised classifier predicting a one-dimensional vect
 target:

 ```julia
-model = fit(learner, data)
+observations = obs(learner, data)
+model = fit(learner, observations)
 X = LearnAPI.features(learner, data)
 y = LearnAPI.target(learner, data)
 ŷ = predict(model, Point(), X)
````

src/obs.jl

Lines changed: 2 additions & 10 deletions
````diff
@@ -61,8 +61,8 @@ using `MLUtils.getobs`, with the obvious interpretation applying to the outcomes
 calls (e.g., if *all* observations are subsampled, then outcomes should be the same as if
 using the original data).

-Implicit in preceding requirements is that `obs(learner, _)` and `obs(model, _)` are
-involutive, meaning both the following hold:
+It is required that `obs(learner, _)` and `obs(model, _)` are involutive, meaning both the
+following hold:

 ```julia
 obs(learner, obs(learner, data)) == obs(learner, data)
@@ -81,14 +81,6 @@ only of suitable tables and arrays, then `obs` and `LearnAPI.data_interface` do
 to be overloaded. However, the user will get no performance benefits by using `obs` in
 that case.

-If overloading `obs(learner, data)` to output new model-specific representations of
-data, it may be necessary to also overload [`LearnAPI.features(learner,
-observations)`](@ref), [`LearnAPI.target(learner, observations)`](@ref) (supervised
-learners), and/or [`LearnAPI.weights(learner, observations)`](@ref) (if weights are
-supported), for each kind output `observations` of `obs(learner, data)`. Moreover, the
-outputs of these methods, applied to `observations`, must also implement the interface
-specified by [`LearnAPI.data_interface(learner)`](@ref).
-
 ## Sample implementation

 Refer to the ["Anatomy of an
````
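The requirement stated in the first `src/obs.jl` hunk, `obs(learner, obs(learner, data)) == obs(learner, data)`, is usually met by adding a pass-through method for data that is already in the learner-specific form. (The docstring calls this property "involutive"; `f(f(x)) == f(x)` is more commonly called idempotence.) The sketch below is self-contained and does not use LearnAPI.jl; `Scaler`, `ScalerObs` and both `obs` methods are hypothetical names chosen for illustration.

```julia
struct Scaler end                        # hypothetical learner

struct ScalerObs                         # hypothetical learner-specific representation
    columns::Vector{Vector{Float64}}
end

# Structural equality, so that two separately built wrappers compare equal.
Base.:(==)(a::ScalerObs, b::ScalerObs) = a.columns == b.columns

# Convert user data (a matrix) into the learner-specific form.
obs(::Scaler, data::AbstractMatrix) =
    ScalerObs([collect(Float64, col) for col in eachcol(data)])

# Pass-through for data that is already converted: this is what makes
# applying `obs` twice the same as applying it once.
obs(::Scaler, observations::ScalerObs) = observations

learner = Scaler()
data = [1 2; 3 4]

once  = obs(learner, data)
twice = obs(learner, obs(learner, data))
```

If the pass-through method were missing, the second call would throw a `MethodError` in this sketch, since `ScalerObs` is not an `AbstractMatrix`.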

src/target_weights_features.jl

Lines changed: 47 additions & 41 deletions
````diff
@@ -1,13 +1,13 @@
 """
-LearnAPI.target(learner, data) -> target
+LearnAPI.target(learner, observations) -> target

-Return, for each form of `data` supported in a call of the form [`fit(learner,
-data)`](@ref), the target variable part of `data`. If `nothing` is returned, the
+Return, for every conceivable `observations` returned by a call of the form [`obs(learner,
+data)`](@ref), the target variable part of `observations`. If `nothing` is returned, the
 `learner` does not see a target variable in training (is unsupervised).

-The returned object `y` has the same number of observations as `data`. If `data` is the
-output of an [`obs`](@ref) call, then `y` is additionally guaranteed to implement the
-data interface specified by [`LearnAPI.data_interface(learner)`](@ref).
+The returned object `y` has the same number of observations as `observations` does and is
+guaranteed to implement the data interface specified by
+[`LearnAPI.data_interface(learner)`](@ref).

 # Extended help

````
````diff
@@ -21,57 +21,61 @@ the LearnAPI.jl documentation.

 ## New implementations

-A fallback returns `nothing`. The method must be overloaded if `fit` consumes data
-including a target variable.
+A fallback returns `nothing`. The method must be overloaded if [`fit`](@ref) consumes data
+that includes a target variable. If `obs` is not being overloaded, then `observations`
+above is any `data` supported in calls of the form [`fit(learner, data)`](@ref). The form
+of the output `y` should be suitable for pairing with the output of [`predict`](@ref), in
+the evaluation of a loss function, for example.

-If overloading [`obs`](@ref), ensure that the return value, unless `nothing`, implements
-the data interface specified by [`LearnAPI.data_interface(learner)`](@ref), in the special
-case that `data` is the output of an `obs` call.
+Ensure the object `y` returned by `LearnAPI.target`, unless `nothing`, implements the data
+interface specified by [`LearnAPI.data_interface(learner)`](@ref).

 $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.target)"; overloaded=true))

 """
-target(::Any, data) = nothing
+target(::Any, observations) = nothing

 """
-LearnAPI.weights(learner, data) -> weights
+LearnAPI.weights(learner, observations) -> weights

-Return, for each form of `data` supported in a call of the form [`fit(learner,
-data)`](@ref), the per-observation weights part of `data`. Where `nothing` is returned, no
-weights are part of `data`, which is to be interpreted as uniform weighting.
+Return, for every conceivable `observations` returned by a call of the form [`obs(learner,
+data)`](@ref), the weights part of `observations`. Where `nothing` is returned, no weights
+are part of `data`, which is to be interpreted as uniform weighting.

-The returned object `w` has the same number of observations as `data`. If `data` is the
-output of an [`obs`](@ref) call, then `w` is additionally guaranteed to implement the
-data interface specified by [`LearnAPI.data_interface(learner)`](@ref).
+The returned object `w` has the same number of observations as `observations` does and is
+guaranteed to implement the data interface specified by
+[`LearnAPI.data_interface(learner)`](@ref).

 # Extended help

 # New implementations

-Overloading is optional. A fallback returns `nothing`.
+Overloading is optional. A fallback returns `nothing`. If `obs` is not being overloaded,
+then `observations` above is any `data` supported in calls of the form [`fit(learner,
+data)`](@ref).

-If overloading [`obs`](@ref), ensure that the return value, unless `nothing`, implements
-the data interface specified by [`LearnAPI.data_interface(learner)`](@ref), in the special
-case that `data` is the output of an `obs` call.
+Ensure the returned object, unless `nothing`, implements the data interface specified by
+[`LearnAPI.data_interface(learner)`](@ref).

 $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.weights)"; overloaded=true))

 """
-weights(::Any, data) = nothing
+weights(::Any, observations) = nothing

 """
-LearnAPI.features(learner, data)
+LearnAPI.features(learner, observations)

-Return, for each form of `data` supported in a call of the form [`fit(learner,
-data)`](@ref), the "features" part of `data` (as opposed to the target
-variable, for example).
+Return, for every conceivable `observations` returned by a call of the form [`obs(learner,
+data)`](@ref), the "features" part of `data` (as opposed to the target variable, for
+example).

 The returned object `X` may always be passed to `predict` or `transform`, where
 implemented, as in the following sample workflow:

 ```julia
-model = fit(learner, data)
-X = LearnAPI.features(learner, data)
+observations = obs(learner, data)
+model = fit(learner, observations)
+X = LearnAPI.features(learner, observations)
 ŷ = predict(model, kind_of_proxy, X) # eg, `kind_of_proxy = Point()`
 ```

````
````diff
@@ -80,28 +84,30 @@ For supervised models (i.e., where `:(LearnAPI.target) in LearnAPI.functions(lea
 data)`, the training target.

 The object `X` returned by `LearnAPI.target` has the same number of observations as
-`data`. If `data` is the output of an [`obs`](@ref) call, then `X` is additionally
-guaranteed to implement the data interface specified by
+`observations` does and is guaranteed to implement the data interface specified by
 [`LearnAPI.data_interface(learner)`](@ref).

 # Extended help

 # New implementations

+A fallback returns `first(observations)` if `observations` is a tuple, and otherwise
+returns `observations`. New implementations may need to overload this method if this
+fallback is inadequate.
+
 For density estimators, whose `fit` typically consumes *only* a target variable, you
-should overload this method to return `nothing`.
+should overload this method to return `nothing`. If `obs` is not being overloaded, then
+`observations` above is any `data` supported in calls of the form [`fit(learner,
+data)`](@ref).

 It must otherwise be possible to pass the return value `X` to `predict` and/or
-`transform`, and `X` must have same number of observations as `data`. A fallback returns
-`first(data)` if `data` is a tuple, and otherwise returns `data`.
+`transform`, and `X` must have same number of observations as `data`.

-Further overloadings may be necessary to handle the case that `data` is the output of
-[`obs(learner, data)`](@ref), if `obs` is being overloaded. In this case, be sure that
-`X`, unless `nothing`, implements the data interface specified by
+Ensure the returned object, unless `nothing`, implements the data interface specified by
 [`LearnAPI.data_interface(learner)`](@ref).

 """
-features(learner, data) = _first(data)
-_first(data) = data
-_first(data::Tuple) = first(data)
+features(learner, observations) = _first(observations)
+_first(observations) = observations
+_first(observations::Tuple) = first(observations)
 # note the factoring above guards against method ambiguities
````
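One plausible reading of the closing comment about method ambiguities is sketched below. Two toy modules compare the factored fallback from this file with a hypothetical unfactored alternative that puts the tuple dispatch directly on `features`. The module names, `MyLearner`, and the implementer's overload are all invented for the illustration and are not part of LearnAPI.jl.

```julia
module Factored
    # Same shape as the fallback in this diff: a single public method, with the
    # tuple/non-tuple dispatch delegated to a private helper.
    features(learner, observations) = _first(observations)
    _first(observations) = observations
    _first(observations::Tuple) = first(observations)
end

module Unfactored
    # Hypothetical alternative: two public fallback methods.
    features(learner, observations) = observations
    features(learner, observations::Tuple) = first(observations)
end

struct MyLearner end   # hypothetical learner

# An implementer overloads on the learner type only, leaving `observations` untyped.
Factored.features(::MyLearner, observations) = last(observations)
Unfactored.features(::MyLearner, observations) = last(observations)

# Factored: the implementer's method is strictly more specific than the lone fallback.
X = Factored.features(MyLearner(), (:y, :X))

# Unfactored: `(::MyLearner, ::Any)` and `(::Any, ::Tuple)` both apply to this call and
# neither is more specific, so Julia reports an ambiguity as a `MethodError`.
ambiguous = try
    Unfactored.features(MyLearner(), (:y, :X))
    false
catch err
    err isa MethodError
end
```

With the factored form, implementers can overload `features` on the learner alone without also having to annotate the second argument to break ties.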

src/traits.jl

Lines changed: 2 additions & 2 deletions
````diff
@@ -387,7 +387,7 @@ iteration_parameter(::Any) = nothing

 Return an upper bound `S` on the scitype of individual observations guaranteed to work
 when calling `fit`: if `observations = obs(learner, data)` and
-`ScientificTypes.scitype(o) <:S` for each `o` in `observations`, then the call
+`ScientificTypes.scitype(collect(o)) <:S` for each `o` in `observations`, then the call
 `fit(learner, data)` is supported.

 $DOC_EXPLAIN_EACHOBS
@@ -396,7 +396,7 @@ See also [`LearnAPI.target_observation_scitype`](@ref).

 # New implementations

-Optional. The fallback return value is `Union{}`.
+Optional. The fallback return value is `Union{}`.

 """
 fit_observation_scitype(::Any) = Union{}
````
