
Commit 8e8123a: doc improvements

1 parent 07c815e

12 files changed: +81 −44 lines

docs/make.jl

Lines changed: 1 addition & 1 deletion
````diff
@@ -18,7 +18,7 @@ makedocs(
         "fit/update" => "fit_update.md",
         "predict/transform" => "predict_transform.md",
         "Kinds of Target Proxy" => "kinds_of_target_proxy.md",
-        "obs" => "obs.md",
+        "obs and Data Interfaces" => "obs.md",
         "target/weights/features" => "target_weights_features.md",
         "Accessor Functions" => "accessor_functions.md",
         "Learner Traits" => "traits.md",
````

docs/src/anatomy_of_an_implementation.md

Lines changed: 38 additions & 7 deletions
````diff
@@ -1,6 +1,6 @@
 # Anatomy of an Implementation
 
-This section explains a detailed implementation of the LearnAPI.jl for naive [ridge
+This tutorial details an implementation of LearnAPI.jl for naive [ridge
 regression](https://en.wikipedia.org/wiki/Ridge_regression) with no intercept. The kind of
 workflow we want to enable has been previewed in [Sample workflow](@ref). Readers can also
 refer to the [demonstration](@ref workflow) of the implementation given later.
@@ -35,8 +35,7 @@ A transformer ordinarily implements `transform` instead of `predict`. For more o
 then an implementation must: (i) overload [`obs`](@ref) to articulate how
 provided data can be transformed into a form that does support
 this interface, as illustrated below under
-[Providing a separate data front end](@ref), and which may additionally
-enable certain performance benefits; or (ii) overload the trait
+[Providing a separate data front end](@ref); or (ii) overload the trait
 [`LearnAPI.data_interface`](@ref) to specify a more relaxed data
 API.
@@ -62,7 +61,7 @@ nothing # hide
 
 Instances of `Ridge` are *[learners](@ref learners)*, in LearnAPI.jl parlance.
 
-Associated with each new type of LearnAPI.jl [learner](@ref learners) will be a keyword
+Associated with each new type of LearnAPI.jl learner will be a keyword
 argument constructor, providing default values for all properties (typically, struct
 fields) that are not other learners, and we must implement
 [`LearnAPI.constructor(learner)`](@ref), for recovering the constructor from an instance:
@@ -365,9 +364,41 @@ y = 2a - b + 3c + 0.05*rand(n)
 An implementation may optionally implement [`obs`](@ref), to expose to the user (or some
 meta-algorithm like cross-validation) the representation of input data internal to `fit`
 or `predict`, such as the matrix version `A` of `X` in the ridge example. That is, we may
-factor out of `fit` (and also `predict`) the data pre-processing step, `obs`, to expose
-its outcomes. These outcomes become alternative user inputs to `fit`/`predict`. To see the
-use of `obs` in action, see [below](@ref advanced_demo).
+factor out of `fit` (and also `predict`) a data pre-processing step, `obs`, to expose
+its outcomes. These outcomes become alternative user inputs to `fit`/`predict`.
+
+In the default case, the alternative data representations will implement the MLUtils.jl
+`getobs/numobs` interface for observation subsampling, which is generally all a user or
+meta-algorithm will need, before passing the data on to `fit`/`predict` as you would the
+original data.
+
+So, instead of the pattern
+
+```julia
+model = fit(learner, data)
+predict(model, newdata)
+```
+
+one enables the following alternative (which in any case will still work, because of a
+no-op `obs` fallback provided by LearnAPI.jl):
+
+```julia
+observations = obs(learner, data) # pre-processed training data
+
+# optional subsampling:
+observations = MLUtils.getobs(observations, train_indices)
+
+model = fit(learner, observations)
+
+newobservations = obs(model, newdata)
+
+# optional subsampling:
+newobservations = MLUtils.getobs(newobservations, test_indices)
+
+predict(model, newobservations)
+```
+
+See also the demonstration [below](@ref advanced_demo).
 
 Here we specifically wrap all the pre-processed data into a single object, for which we
 introduce a new type:
````
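The diffed excerpt ends before showing that type. A plausible sketch of such a wrapper, supporting MLUtils.jl subsampling as the text describes, is given below; the names `RidgeFitObs` and its fields are assumptions reconstructed from context, not taken from the diffed source:

```julia
import MLUtils

# Hypothetical container for the pre-processed ridge training data
# returned by `obs(learner, data)`:
struct RidgeFitObs{T}
    A::Matrix{T}          # feature matrix (observations as columns, assumed)
    names::Vector{Symbol} # feature names
    y::Vector{T}          # target vector
end

# observation subsampling, so `MLUtils.getobs(observations, I)` works:
Base.getindex(data::RidgeFitObs, I) = RidgeFitObs(data.A[:, I], data.names, data.y[I])
MLUtils.getobs(data::RidgeFitObs, I) = data[I]
MLUtils.numobs(data::RidgeFitObs) = length(data.y)
```

With such a wrapper, `fit` can accept either raw user data or a (possibly subsampled) `RidgeFitObs` instance.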

docs/src/obs.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -47,8 +47,6 @@ import MLUtils
 learner = <some supervised learner>
 
 data = <some data that `fit` can consume, with 30 observations>
-X = LearnAPI.features(learner, data)
-y = LearnAPI.target(learner, data)
 
 train_test_folds = map([1:10, 11:20, 21:30]) do test
     (setdiff(1:30, test), test)
@@ -65,12 +63,14 @@ scores = map(train_test_folds) do (train, test)
 
     # predict on the fold complement:
     if never_trained
+        X = LearnAPI.features(learner, data)
         global predictobs = obs(model, X)
         global never_trained = false
     end
     predictobs_subset = MLUtils.getobs(predictobs, test)
     ŷ = predict(model, Point(), predictobs_subset)
 
+    y = LearnAPI.target(learner, data)
     return <score comparing ŷ with y[test]>
 
 end
@@ -96,8 +96,8 @@ obs
 ### [Data interfaces](@id data_interfaces)
 
 New implementations must overload [`LearnAPI.data_interface(learner)`](@ref) if the
-output of [`obs`](@ref) does not implement [`LearnAPI.RandomAccess`](@ref). (Arrays, most
-tables, and all tuples thereof, implement `RandomAccess`.)
+output of [`obs`](@ref) does not implement [`LearnAPI.RandomAccess()`](@ref). Arrays, most
+tables, and all tuples thereof, implement `RandomAccess()`.
 
 - [`LearnAPI.RandomAccess`](@ref) (default)
 - [`LearnAPI.FiniteIterable`](@ref)
````

docs/src/predict_transform.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -86,7 +86,7 @@ dimension using distances from the cluster centres.
 
 Learners may additionally overload `transform` to apply `fit` first, using the supplied
 data if required, and then immediately `transform` the same data. In that case the first
-argument of `transform` is an *learner* instead of the output of `fit`:
+argument of `transform` is a *learner* instead of the output of `fit`:
 
 ```julia
 transform(learner, data) # `fit` implied
````

docs/src/reference.md

Lines changed: 13 additions & 7 deletions
````diff
@@ -80,9 +80,8 @@ Informally, we will sometimes use the word "model" to refer to the output of
 `fit(learner, ...)` (see below), something which typically *does* store learned
 parameters.
 
-For `learner` to be a valid LearnAPI.jl learner,
-[`LearnAPI.constructor(learner)`](@ref) must be defined and return a keyword constructor
-enabling recovery of `learner` from its properties:
+For every `learner`, [`LearnAPI.constructor(learner)`](@ref) must return a keyword
+constructor enabling recovery of `learner` from its properties:
 
 ```julia
 properties = propertynames(learner)
@@ -92,7 +91,7 @@ named_properties = NamedTuple{properties}(getproperty.(Ref(learner), properties)
 
 which can be tested with `@assert `[`LearnAPI.clone(learner)`](@ref)` == learner`.
 
-Note that if if `learner` is an instance of a *mutable* struct, this requirement
+Note that if `learner` is an instance of a *mutable* struct, this requirement
 generally requires overloading `Base.==` for the struct.
 
 !!! important
@@ -124,6 +123,13 @@ struct GradientRidgeRegressor{T<:Real}
     epochs::Int
     l2_regularization::T
 end
+
+"""
+    GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01)
+
+Instantiate a gradient ridge regressor with the specified hyperparameters.
+
+"""
 GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01) =
     GradientRidgeRegressor(learning_rate, epochs, l2_regularization)
 LearnAPI.constructor(::GradientRidgeRegressor) = GradientRidgeRegressor
@@ -132,9 +138,9 @@ LearnAPI.constructor(::GradientRidgeRegressor) = GradientRidgeRegressor
 ## Documentation
 
 Attach public LearnAPI.jl-related documentation for a learner to its *constructor*,
-rather than to the struct defining its type. In this way, a learner can implement
-multiple interfaces, in addition to the LearnAPI interface, with separate document strings
-for each.
+rather than to the struct defining its type, as shown in the example above. (In this way,
+multiple interfaces can share a common struct, with separate document strings for each
+interface.)
 
 ## Methods
````
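The constructor/recovery contract described in this hunk can be checked mechanically. A minimal sketch, assuming a hypothetical learner type `MyLearner` (not part of the diff):

```julia
import LearnAPI

# hypothetical learner with one hyperparameter:
struct MyLearner
    lambda::Float64
end

# keyword constructor supplying defaults:
MyLearner(; lambda=0.1) = MyLearner(lambda)

# recover the constructor from an instance:
LearnAPI.constructor(::MyLearner) = MyLearner

# the recovery property quoted above:
learner = MyLearner(lambda=0.5)
properties = propertynames(learner)
named_properties = NamedTuple{properties}(getproperty.(Ref(learner), properties))
@assert LearnAPI.constructor(learner)(; named_properties...) == learner
```

Because `MyLearner` is immutable, the default `==` suffices here; a mutable struct would need `Base.==` overloaded, as the hunk notes.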

src/fit_update.jl

Lines changed: 3 additions & 2 deletions
````diff
@@ -21,7 +21,8 @@ The signature `fit(learner; verbosity=...)` (no `data`) is provided by learners
 not generalize to new observations (called *static algorithms*). In that case,
 `transform(model, data)` or `predict(model, ..., data)` carries out the actual algorithm
 execution, writing any byproducts of that operation to the mutable object `model` returned
-by `fit`.
+by `fit`. Inspect the value of [`LearnAPI.is_static(learner)`](@ref) to determine whether
+`fit` consumes `data` or not.
 
 Use `verbosity=0` for warnings only, and `-1` for silent training.
@@ -117,7 +118,7 @@ learner = MyNeuralNetwork(epochs=10, learning_rate=0.01)
 model = fit(learner, data)
 
 # train for two more epochs using new data and new learning rate:
-model = update_observations(model, new_data; epochs=2, learning_rate=0.1)
+model = update_observations(model, new_data; epochs=12, learning_rate=0.1)
 ```
 
 When following the call `fit(learner, data)`, the `update` call is semantically
````
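The `LearnAPI.is_static` inspection added in the first hunk can be used like this; a sketch only, with `learner` and `data` standing in for real objects:

```julia
import LearnAPI

# Branch on whether `fit` consumes data (static algorithms do not):
if LearnAPI.is_static(learner)
    model = fit(learner)         # no data at training time
    W = transform(model, data)   # the actual computation happens here,
                                 # byproducts written to mutable `model`
else
    model = fit(learner, data)   # training generalizes to new observations
    W = transform(model, data)
end
```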

src/obs.jl

Lines changed: 2 additions & 5 deletions
````diff
@@ -25,23 +25,20 @@ model = fit(learner, data_train)
 ŷ = predict(model, Point(), X[101:150])
 ```
 
-Alternative, data agnostic, workflow using `obs` and the MLUtils.jl method `getobs`
-(assumes `LearnAPI.data_interface(learner) == RandomAccess()`):
+Alternative workflow using `obs` and the MLUtils.jl method `getobs` to carry out
+subsampling (assumes `LearnAPI.data_interface(learner) == RandomAccess()`):
 
 ```julia
 import MLUtils
-
 fit_observations = obs(learner, data)
 model = fit(learner, MLUtils.getobs(fit_observations, 1:100))
-
 predict_observations = obs(model, X)
 ẑ = predict(model, Point(), MLUtils.getobs(predict_observations, 101:150))
 @assert ẑ == ŷ
 ```
 
 See also [`LearnAPI.data_interface`](@ref).
-
 # Extended help
 
 # New implementations
````

src/predict_transform.jl

Lines changed: 8 additions & 5 deletions
````diff
@@ -8,7 +8,8 @@ DOC_MUTATION(op) =
     """
 
     If [`LearnAPI.is_static(learner)`](@ref) is `true`, then `$op` may mutate its first
-    argument, but not in a way that alters the result of a subsequent call to `predict`,
+    argument (to record byproducts of the computation not naturally part of the return
+    value) but not in a way that alters the result of a subsequent call to `predict`,
     `transform` or `inverse_transform`. See more at [`fit`](@ref).
 
     """
@@ -82,8 +83,9 @@ See also [`fit`](@ref), [`transform`](@ref), [`inverse_transform`](@ref).
 
 # Extended help
 
-Note `predict` must not mutate any argument, except in the special case
-`LearnAPI.is_static(learner) == true`.
+In the special case `LearnAPI.is_static(learner) == true`, it is possible that
+`predict(model, ...)` will mutate `model`, but not in a way that affects subsequent
+`predict` calls.
 
 # New implementations
@@ -147,8 +149,9 @@ or, in one step (where supported):
 W = transform(learner, X) # `fit` implied
 ```
 
-Note `transform` does not mutate any argument, except in the special case
-`LearnAPI.is_static(learner) == true`.
+In the special case `LearnAPI.is_static(learner) == true`, it is possible that
+`transform(model, ...)` will mutate `model`, but not in a way that affects subsequent
+`transform` calls.
 
 See also [`fit`](@ref), [`predict`](@ref),
 [`inverse_transform`](@ref).
````

src/target_weights_features.jl

Lines changed: 2 additions & 3 deletions
````diff
@@ -80,10 +80,9 @@ ŷ = predict(model, kind_of_proxy, X) # eg, `kind_of_proxy = Point()`
 ```
 
 For supervised models (i.e., where `:(LearnAPI.target) in LearnAPI.functions(learner)`)
-`ŷ` above is generally intended to be an approximate proxy for `LearnAPI.target(learner,
-data)`, the training target.
+`ŷ` above is generally intended to be an approximate proxy for the target variable.
 
-The object `X` returned by `LearnAPI.target` has the same number of observations as
+The object `X` returned by `LearnAPI.features` has the same number of observations as
 `observations` does and is guaranteed to implement the data interface specified by
 [`LearnAPI.data_interface(learner)`](@ref).
````
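For concreteness, the relationship between `LearnAPI.features`, `LearnAPI.target`, and `predict` described in this hunk can be sketched for tuple-form data. That `(X, y)` tuples split this way is an assumption about common fallback behavior, not something stated in the diff:

```julia
import LearnAPI

data = (X, y)  # supervised data as a (features, target) tuple (assumed form)

# extract the parts consumed and predicted by the model:
features = LearnAPI.features(learner, data)  # expected to act like X
target = LearnAPI.target(learner, data)      # expected to act like y

model = fit(learner, data)
ŷ = predict(model, Point(), features)        # approximate proxy for `target`
```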

src/traits.jl

Lines changed: 2 additions & 2 deletions
````diff
@@ -79,9 +79,9 @@ All new implementations must implement this trait. Here's a checklist for elemen
 return value:
 
 | expression             | implementation compulsory? | include in returned tuple?         |
-|------------------------|----------------------------|------------------------------------|
+|:-----------------------|:---------------------------|:-----------------------------------|
 | `:(LearnAPI.fit)`      | yes                        | yes                                |
-| `:(LearnAPI.learner)`  | yes                        | yes                                |
+| `:(LearnAPI.learner)`  | yes                        | yes                                |
 | `:(LearnAPI.strip)`    | no                         | yes                                |
 | `:(LearnAPI.obs)`      | no                         | yes                                |
 | `:(LearnAPI.features)` | no                         | yes, unless `fit` consumes no data |
````
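A `LearnAPI.functions` return value consistent with the checklist above might look as follows; `MyLearner` is a hypothetical learner type introduced only for illustration:

```julia
import LearnAPI

struct MyLearner end  # placeholder learner type

# declare which LearnAPI.jl methods this learner implements:
LearnAPI.functions(::MyLearner) = (
    :(LearnAPI.fit),      # compulsory
    :(LearnAPI.learner),  # compulsory
    :(LearnAPI.strip),
    :(LearnAPI.obs),
    :(LearnAPI.features), # included, since `fit` consumes data
    :(LearnAPI.predict),
)
```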
