
Commit 1e946cf

committed
make multiple changes to accessor functions; add clone to functions()
1 parent f4aed66 commit 1e946cf

16 files changed, +168 -86 lines changed

docs/src/accessor_functions.md

Lines changed: 6 additions & 4 deletions
@@ -13,9 +13,10 @@ The sole argument of an accessor function is the output, `model`, of
 - [`LearnAPI.trees(model)`](@ref)
 - [`LearnAPI.feature_names(model)`](@ref)
 - [`LearnAPI.feature_importances(model)`](@ref)
-- [`LearnAPI.training_labels(model)`](@ref)
 - [`LearnAPI.training_losses(model)`](@ref)
-- [`LearnAPI.training_predictions(model)`](@ref)
+- [`LearnAPI.out_of_sample_losses(model)`](@ref)
+- [`LearnAPI.predictions(model)`](@ref)
+- [`LearnAPI.out_of_sample_indices(model)`](@ref)
 - [`LearnAPI.training_scores(model)`](@ref)
 - [`LearnAPI.components(model)`](@ref)
 

@@ -42,9 +43,10 @@ LearnAPI.trees
 LearnAPI.feature_names
 LearnAPI.feature_importances
 LearnAPI.training_losses
-LearnAPI.training_predictions
+LearnAPI.out_of_sample_losses
+LearnAPI.predictions
+LearnAPI.out_of_sample_indices
 LearnAPI.training_scores
-LearnAPI.training_labels
 LearnAPI.components
 ```
 
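For orientation, here is a minimal sketch of how the renamed accessors might be called, assuming `model` is the output of `fit` for a learner that implements them. The names `learner`, `X` and `y` are placeholders, and the comments only paraphrase what the accessor names suggest:

```julia
using LearnAPI

model = fit(learner, (X, y))    # `learner`, `X`, `y` are placeholders, not defined here

LearnAPI.predictions(model)             # predictions recorded during training, if any
LearnAPI.out_of_sample_losses(model)    # losses on out-of-sample (held-out) observations
LearnAPI.out_of_sample_indices(model)   # indices of those out-of-sample observations
```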

docs/src/anatomy_of_an_implementation.md

Lines changed: 3 additions & 2 deletions
@@ -50,7 +50,7 @@ nothing # hide
 
 ## Defining learners
 
-Here's a new type whose instances specify ridge regression hyperparameters:
+Here's a new type whose instances specify the single ridge regression hyperparameter:
 
 ```@example anatomy
 struct Ridge{T<:Real}

@@ -280,7 +280,7 @@ nothing # hide
 
 ```@example anatomy
 learner = Ridge(lambda=0.5)
-foreach(println, LearnAPI.functions(learner))
+@functions learner
 ```
 
 Training and predicting:

@@ -344,6 +344,7 @@ LearnAPI.strip(model::RidgeFitted) =
     functions = (
         :(LearnAPI.fit),
         :(LearnAPI.learner),
+        :(LearnAPI.clone),
        :(LearnAPI.strip),
         :(LearnAPI.obs),
         :(LearnAPI.features),
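As a rough sketch of what these two changes give the user of the `Ridge` learner defined on this documentation page (the keyword-replacement form of `clone` is assumed from its one-line description in the reference docs, not quoted from them):

```julia
using LearnAPI

learner = Ridge(lambda=0.5)

# List the functions declared in the `functions` trait, which now includes `LearnAPI.clone`:
@functions learner

# Clone the learner with one hyperparameter replaced (keyword syntax assumed):
learner2 = clone(learner; lambda=1.0)
```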

docs/src/fit_update.md

Lines changed: 2 additions & 2 deletions
@@ -111,8 +111,8 @@ Exactly one of the following must be implemented:
 | method | fallback | compulsory? |
 |:-------------------------------------------------------------------------------------|:---------|-------------|
 | [`update`](@ref)`(model, data; verbosity=..., hyperparameter_updates...)` | none | no |
-| [`update_observations`](@ref)`(model, data; verbosity=..., hyperparameter_updates...)` | none | no |
-| [`update_features`](@ref)`(model, data; verbosity=..., hyperparameter_updates...)` | none | no |
+| [`update_observations`](@ref)`(model, new_data; verbosity=..., hyperparameter_updates...)` | none | no |
+| [`update_features`](@ref)`(model, new_data; verbosity=..., hyperparameter_updates...)` | none | no |
 
 There are some contracts governing the behaviour of the update methods, as they relate to
 a previous `fit` call. Consult the document strings for details.
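By way of illustration only, a sketch of how the renamed `new_data` argument is meant to be read; `learner`, `data`, `new_data` and the `epochs` hyperparameter below are hypothetical placeholders:

```julia
using LearnAPI

model = fit(learner, data)    # `learner` and `data` are placeholders

# Train further on observations not previously seen, optionally updating hyperparameters
# (`epochs` is a hypothetical hyperparameter of `learner`):
model = update_observations(model, new_data; verbosity=1, epochs=10)
```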

docs/src/index.md

Lines changed: 3 additions & 8 deletions
@@ -50,9 +50,9 @@ enable the basic workflow below. In this case data is presented following the
 "scikit-learn" `X, y` pattern, although LearnAPI.jl supports other patterns as well.
 
 ```julia
-X = <some training features>
-y = <some training target>
-Xnew = <some test or production features>
+# `X` is some training features
+# `y` is some training target
+# `Xnew` is some test or production features
 
 # List LearnaAPI functions implemented for `forest`:
 @functions forest

@@ -72,11 +72,6 @@ LearnAPI.feature_importances(model)
 # Slim down and otherwise prepare model for serialization:
 small_model = LearnAPI.strip(model)
 serialize("my_random_forest.jls", small_model)
-
-# Recover saved model and algorithm configuration ("learner"):
-recovered_model = deserialize("my_random_forest.jls")
-@assert LearnAPI.learner(recovered_model) == forest
-@assert predict(recovered_model, Point(), Xnew) ==
 ```
 
 `Distribution` and `Point` are singleton types owned by LearnAPI.jl. They allow

docs/src/patterns/ensembling.md

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ See these examples from the JuliaTestAPI.jl test suite:
 
 - [bagged ensembling of a regression model](https://github.com/JuliaAI/LearnTestAPI.jl/blob/dev/test/patterns/ensembling.jl)
 
+- [extremely randomized ensemble of decision stumps (regression)](https://github.com/JuliaAI/LearnTestAPI.jl/blob/dev/test/patterns/ensembling.jl)

docs/src/patterns/iterative_algorithms.md

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@ See these examples from the JuliaTestAI.jl test suite:
 - [bagged ensembling](https://github.com/JuliaAI/LearnTestAPI.jl/blob/dev/test/patterns/ensembling.jl)
 
 - [perceptron classifier](https://github.com/JuliaAI/LearnTestAPI.jl/blob/dev/test/patterns/gradient_descent.jl)
+
+- [extremely randomized ensemble of decision stumps (regression)](https://github.com/JuliaAI/LearnTestAPI.jl/blob/dev/test/patterns/ensembling.jl)

docs/src/patterns/regression.md

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ See these examples from the JuliaTestAPI.jl test suite:
 
 - [ridge regression](https://github.com/JuliaAI/LearnTestAPI.jl/blob/dev/test/patterns/regression.jl)
 
+- [extremely randomized ensemble of decision stumps](https://github.com/JuliaAI/LearnTestAPI.jl/blob/dev/test/patterns/ensembling.jl)

docs/src/reference.md

Lines changed: 21 additions & 17 deletions
@@ -27,9 +27,9 @@ see [`obs`](@ref) and [`LearnAPI.data_interface`](@ref) for details.
 
 !!! note
 
-    In the MLUtils.jl
-    convention, observations in tables are the rows but observations in a matrix are the
-    columns.
+    In the MLUtils.jl
+    convention, observations in tables are the rows but observations in a matrix are the
+    columns.
 
 ### [Hyperparameters](@id hyperparameters)
 

@@ -96,9 +96,9 @@ generally requires overloading `Base.==` for the struct.
 
 !!! important
 
-    No LearnAPI.jl method is permitted to mutate a learner. In particular, one should make
-    deep copies of RNG hyperparameters before using them in a new implementation of
-    [`fit`](@ref).
+    No LearnAPI.jl method is permitted to mutate a learner. In particular, one should make
+    deep copies of RNG hyperparameters before using them in a new implementation of
+    [`fit`](@ref).
 
 #### Composite learners (wrappers)
 

@@ -119,19 +119,19 @@ Below is an example of a learner type with a valid constructor:
 
 ```julia
 struct GradientRidgeRegressor{T<:Real}
-    learning_rate::T
-    epochs::Int
-    l2_regularization::T
+    learning_rate::T
+    epochs::Int
+    l2_regularization::T
 end
 
 """
-    GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01)
-
+    GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01)
+
 Instantiate a gradient ridge regressor with the specified hyperparameters.
 
 """
 GradientRidgeRegressor(; learning_rate=0.01, epochs=10, l2_regularization=0.01) =
-    GradientRidgeRegressor(learning_rate, epochs, l2_regularization)
+    GradientRidgeRegressor(learning_rate, epochs, l2_regularization)
 LearnAPI.constructor(::GradientRidgeRegressor) = GradientRidgeRegressor
 ```
 

@@ -146,9 +146,9 @@ interface.)
 
 !!! note "Compulsory methods"
 
-    All new learner types must implement [`fit`](@ref),
-    [`LearnAPI.learner`](@ref), [`LearnAPI.constructor`](@ref) and
-    [`LearnAPI.functions`](@ref).
+    All new learner types must implement [`fit`](@ref),
+    [`LearnAPI.learner`](@ref), [`LearnAPI.constructor`](@ref) and
+    [`LearnAPI.functions`](@ref).
 
 Most learners will also implement [`predict`](@ref) and/or [`transform`](@ref). For a
 minimal (but useless) implementation, see the implementation of `SmallLearner`

@@ -198,10 +198,14 @@ minimal (but useless) implementation, see the implementation of `SmallLearner`
 
 ## Utilities
 
+- [`clone`](@ref): for cloning a learner with specified hyperparameter replacements.
+- [`@trait`](@ref): for simultaneously declaring multiple traits
+- [`@functions`](@ref): for listing functions available for use with a learner
+
 ```@docs
-@functions
-LearnAPI.clone
+clone
 @trait
+@functions
 ```
 
 ---
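A brief sketch of the intended `clone` behaviour, inferred from its one-line description in the new Utilities list above; the keyword-replacement syntax and the `lambda` hyperparameter are assumptions for illustration, not verbatim API documentation:

```julia
using LearnAPI

# `learner` is any LearnAPI.jl learner with a `lambda` hyperparameter (illustrative):
learner2 = clone(learner; lambda=0.1)

# A clone is expected to be the same kind of learner, i.e. to share the constructor:
@assert LearnAPI.constructor(learner2) == LearnAPI.constructor(learner)
```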

src/LearnAPI.jl

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ include("accessor_functions.jl")
 include("traits.jl")
 include("clone.jl")
 
-export @trait, @functions
+export @trait, @functions, clone
 export fit, update, update_observations, update_features
 export predict, transform, inverse_transform, obs
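The practical effect of the new export, sketched for user code:

```julia
using LearnAPI

# `clone` is now exported, so the unqualified name resolves to the same function:
@assert clone === LearnAPI.clone
```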
