Commit d37e60e

replace is_composite trait with nonlearners trait
1 parent 1e946cf commit d37e60e

11 files changed

+166
-128
lines changed

docs/src/anatomy_of_an_implementation.md

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
# Anatomy of an Implementation
22

3-
This tutorial details an implementation of the LearnAPI.jl for naive [ridge
4-
regression](https://en.wikipedia.org/wiki/Ridge_regression) with no intercept. The kind of
5-
workflow we want to enable has been previewed in [Sample workflow](@ref). Readers can also
6-
refer to the [demonstration](@ref workflow) of the implementation given later.
7-
83
The core LearnAPI.jl pattern looks like this:
94

105
```julia
@@ -14,9 +9,21 @@ predict(model, newdata)
149

1510
Here `learner` specifies hyperparameters, while `model` stores learned parameters and any byproducts of algorithm execution.
1611

17-
A transformer ordinarily implements `transform` instead of `predict`. For more on
12+
[Transformers](@ref) ordinarily implement `transform` instead of `predict`. For more on
1813
`predict` versus `transform`, see [Predict or transform?](@ref)
1914

15+
["Static" algorithms](@ref static_algorithms) have a `fit` that consumes no `data`
16+
(instead `predict` or `transform` does the heavy lifting). In [density
17+
estimation](@ref density_estimation), `predict` consumes no data.
18+
19+
These are the basic possibilities.
20+
21+
Elaborating on the core pattern above, we detail in this tutorial an implementation of the
22+
LearnAPI.jl interface for naive [ridge regression](https://en.wikipedia.org/wiki/Ridge_regression)
23+
with no intercept. The kind of workflow we want to enable has been previewed in [Sample
24+
workflow](@ref). Readers can also refer to the [demonstration](@ref workflow) of the
25+
implementation given later.
26+
2027
!!! note
2128

2229
New implementations of `fit`, `predict`, etc,
@@ -102,7 +109,7 @@ nothing # hide
102109
Note that we also include `learner` in the struct, for it must be possible to recover
103110
`learner` from the output of `fit`; see [Accessor functions](@ref) below.
104111

105-
The core implementation of `fit` looks like this:
112+
The implementation of `fit` looks like this:
106113

107114
```@example anatomy
108115
function LearnAPI.fit(learner::Ridge, data; verbosity=LearnAPI.default_verbosity())
@@ -131,7 +138,7 @@ end
131138

132139
## Implementing `predict`
133140

134-
Users will be able to call `predict` like this:
141+
One way users will be able to call `predict` is like this:
135142

136143
```julia
137144
predict(model, Point(), Xnew)
@@ -229,6 +236,7 @@ A macro provides a shortcut, convenient when multiple traits are to be defined:
229236
functions = (
230237
:(LearnAPI.fit),
231238
:(LearnAPI.learner),
239+
:(LearnAPI.clone),
232240
:(LearnAPI.strip),
233241
:(LearnAPI.obs),
234242
:(LearnAPI.features),
@@ -241,12 +249,17 @@ nothing # hide
241249
```
242250

243251
The last trait, `functions`, returns a list of all LearnAPI.jl methods that can be
244-
meaningfully applied to the learner or associated model. See [`LearnAPI.functions`](@ref)
245-
for a checklist. [`LearnAPI.functions`](@ref) and [`LearnAPI.constructor`](@ref), are the
246-
only universally compulsory traits. However, it is worthwhile studying the [list of all
247-
traits](@ref traits_list) to see which might apply to a new implementation, to enable
248-
maximum buy into functionality provided by third party packages, and to assist third party
249-
algorithms that match machine learning algorithms to user-defined tasks.
252+
meaningfully applied to the learner or associated model. You always include the first five
253+
you see here: `fit`, `learner`, `clone`, `strip`, `obs`. Here [`clone`](@ref) is a utility
254+
function provided by LearnAPI that you never overload; overloading [`obs`](@ref) is
255+
optional (see [Providing a separate data front end](@ref)) but it is always included
256+
because it has a fallback. See [`LearnAPI.functions`](@ref) for a checklist.
257+
258+
[`LearnAPI.functions`](@ref) and [`LearnAPI.constructor`](@ref) are the only universally
259+
compulsory traits. However, it is worthwhile studying the [list of all traits](@ref
260+
traits_list) to see which might apply to a new implementation, to enable maximum buy-in to
261+
functionality provided by third party packages, and to assist third party algorithms that
262+
match machine learning algorithms to user-defined tasks.
250263

251264
Note that we know `Ridge` instances are supervised learners because `:(LearnAPI.target)
252265
in LearnAPI.functions(learner)`, for every instance `learner`. With [some

docs/src/fit_update.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ LearnAPI.extras(model)
7676

7777
See also [Static Algorithms](@ref)
7878

79-
### Density estimation
79+
### [Density estimation](@id density_estimation)
8080

8181
In density estimation, `fit` consumes no features, only a target variable; `predict`,
8282
which consumes no data, returns the learned density:

docs/src/predict_transform.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ transform(model, data)
66
inverse_transform(model, data)
77
```
88

9-
Versions without the `data` argument may apply, for example in [Density
10-
estimation](@ref).
9+
Versions without the `data` argument may apply, for example in [density
10+
estimation](@ref density_estimation).
1111

1212
## [Typical workflows](@id predict_workflow)
1313

docs/src/reference.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,11 @@ generally requires overloading `Base.==` for the struct.
103103
#### Composite learners (wrappers)
104104

105105
A *composite learner* is one with at least one property that can take other learners as
106-
values; for such learners [`LearnAPI.is_composite`](@ref)`(learner)` must be `true`
107-
(fallback is `false`). Generally, the keyword constructor provided by
108-
[`LearnAPI.constructor`](@ref) must provide default values for all properties that are not
109-
learner-valued. Instead, these learner-valued properties can have a `nothing` default,
110-
with the constructor throwing an error if the constructor call does not explicitly
111-
specify a new value.
106+
values; for such learners [`LearnAPI.learners(learner)`](@ref) is non-empty. A keyword
107+
constructor provided by [`LearnAPI.constructor`](@ref) must provide default values for all
108+
properties that are not in [`LearnAPI.learners(learner)`](@ref). Instead, these
109+
learner-valued properties can have a `nothing` default, with the constructor throwing an
110+
error if the constructor call does not explicitly specify a new value.
112111

113112
Any object `learner` for which [`LearnAPI.functions(learner)`](@ref) is non-empty is
114113
understood to have a valid implementation of the LearnAPI.jl interface.
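
The constructor convention described in this hunk can be illustrated with a minimal sketch. The types `Ridge` and `Ensemble` below are hypothetical and are not taken from LearnAPI.jl; the sketch only shows a keyword constructor that gives ordinary hyperparameters real defaults, gives the learner-valued property a `nothing` default, and throws if that property is not specified.

```julia
# Hypothetical types for illustration only; none of these names come from LearnAPI.jl.
struct Ridge
    lambda::Float64
end
Ridge(; lambda=0.1) = Ridge(lambda)

# A composite learner (wrapper): `atom` is learner-valued, `n` is an ordinary
# hyperparameter.
struct Ensemble{L}
    atom::L
    n::Int
end

# Keyword constructor: `n` has a usable default, while the learner-valued `atom`
# defaults to `nothing` and must be supplied explicitly by the caller.
function Ensemble(; atom=nothing, n=10)
    atom === nothing && throw(ArgumentError("You must specify `atom=...`."))
    return Ensemble(atom, n)
end

ensemble = Ensemble(atom=Ridge(lambda=0.5))
```

Calling `Ensemble()` with no `atom` raises an `ArgumentError`, so a wrapper can never silently wrap a placeholder.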

docs/src/traits.md

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,45 +13,48 @@ training). They may also record more mundane information, such as a package lice
1313
In the examples column of the table below, `Continuous` is a name owned by the package
1414
[ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.jl/).
1515

16-
| trait | return value | fallback value | example |
17-
|:-----------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------|:-----------------------------------------------------------|
18-
| [`LearnAPI.constructor`](@ref)`(learner)` | constructor for generating new or modified versions of `learner` | (no fallback) | `RidgeRegressor` |
16+
| trait | return value | fallback value | example |
17+
|:---------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------|:---------------------------------------------------------------|
18+
| [`LearnAPI.constructor`](@ref)`(learner)` | constructor for generating new or modified versions of `learner` | (no fallback) | `RidgeRegressor` |
1919
| [`LearnAPI.functions`](@ref)`(learner)` | functions you can apply to `learner` or associated model (traits excluded) | `()` | `(:fit, :predict, :(LearnAPI.strip), :(LearnAPI.learner), :obs)` |
20-
| [`LearnAPI.kinds_of_proxy`](@ref)`(learner)` | instances `kind` of `KindOfProxy` for which an implementation of `LearnAPI.predict(learner, kind, ...)` is guaranteed. | `()` | `(Distribution(), Interval())` |
21-
| [`LearnAPI.tags`](@ref)`(learner)` | lists one or more suggestive learner tags from `LearnAPI.tags()` | `()` | (:regression, :probabilistic) |
22-
| [`LearnAPI.is_pure_julia`](@ref)`(learner)` | `true` if implementation is 100% Julia code | `false` | `true` |
23-
| [`LearnAPI.pkg_name`](@ref)`(learner)` | name of package providing core code (may be different from package providing LearnAPI.jl implementation) | `"unknown"` | `"DecisionTree"` |
24-
| [`LearnAPI.pkg_license`](@ref)`(learner)` | name of license of package providing core code | `"unknown"` | `"MIT"` |
25-
| [`LearnAPI.doc_url`](@ref)`(learner)` | url providing documentation of the core code | `"unknown"` | `"https://en.wikipedia.org/wiki/Decision_tree_learning"` |
26-
| [`LearnAPI.load_path`](@ref)`(learner)` | string locating name returned by `LearnAPI.constructor(learner)`, beginning with a package name | "unknown"` | `FastTrees.LearnAPI.DecisionTreeClassifier` |
27-
| [`LearnAPI.is_composite`](@ref)`(learner)` | `true` if one or more properties of `learner` may be a learner | `false` | `true` |
28-
| [`LearnAPI.human_name`](@ref)`(learner)` | human name for the learner; should be a noun | type name with spaces | "elastic net regressor" |
29-
| [`LearnAPI.iteration_parameter`](@ref)`(learner)` | symbolic name of an iteration parameter | `nothing` | :epochs |
30-
| [`LearnAPI.data_interface`](@ref)`(learner)` | Interface implemented by objects returned by [`obs`](@ref) | `Base.HasLength()` (supports `MLUtils.getobs/numobs`) | `Base.SizeUnknown()` (supports `iterate`) |
31-
| [`LearnAPI.fit_observation_scitype`](@ref)`(learner)` | upper bound on `scitype(observation)` for `observation` in `data` ensuring `fit(learner, data)` works | `Union{}` | `Tuple{AbstractVector{Continuous}, Continuous}` |
32-
| [`LearnAPI.target_observation_scitype`](@ref)`(learner)` | upper bound on the scitype of each observation of the targget | `Any` | `Continuous` |
33-
| [`LearnAPI.is_static`](@ref)`(learner)` | `true` if `fit` consumes no data | `false` | `true` |
20+
| [`LearnAPI.kinds_of_proxy`](@ref)`(learner)` | instances `kind` of `KindOfProxy` for which an implementation of `LearnAPI.predict(learner, kind, ...)` is guaranteed. | `()` | `(Distribution(), Interval())` |
21+
| [`LearnAPI.tags`](@ref)`(learner)` | lists one or more suggestive learner tags from `LearnAPI.tags()` | `()` | (:regression, :probabilistic) |
22+
| [`LearnAPI.is_pure_julia`](@ref)`(learner)` | `true` if implementation is 100% Julia code | `false` | `true` |
23+
| [`LearnAPI.pkg_name`](@ref)`(learner)` | name of package providing core code (may be different from package providing LearnAPI.jl implementation) | `"unknown"` | `"DecisionTree"` |
24+
| [`LearnAPI.pkg_license`](@ref)`(learner)` | name of license of package providing core code | `"unknown"` | `"MIT"` |
25+
| [`LearnAPI.doc_url`](@ref)`(learner)` | url providing documentation of the core code | `"unknown"` | `"https://en.wikipedia.org/wiki/Decision_tree_learning"` |
26+
| [`LearnAPI.load_path`](@ref)`(learner)` | string locating name returned by `LearnAPI.constructor(learner)`, beginning with a package name | `"unknown"` | `FastTrees.LearnAPI.DecisionTreeClassifier` |
27+
| [`LearnAPI.nonlearners`](@ref)`(learner)` | properties *not* corresponding to other learners | all properties | `(:K, :leafsize, :metric,)` |
28+
| [`LearnAPI.human_name`](@ref)`(learner)` | human name for the learner; should be a noun | type name with spaces | "elastic net regressor" |
29+
| [`LearnAPI.iteration_parameter`](@ref)`(learner)` | symbolic name of an iteration parameter | `nothing` | :epochs |
30+
| [`LearnAPI.data_interface`](@ref)`(learner)` | Interface implemented by objects returned by [`obs`](@ref) | `Base.HasLength()` (supports `MLUtils.getobs/numobs`) | `Base.SizeUnknown()` (supports `iterate`) |
31+
| [`LearnAPI.fit_observation_scitype`](@ref)`(learner)` | upper bound on `scitype(observation)` for `observation` in `data` ensuring `fit(learner, data)` works | `Union{}` | `Tuple{AbstractVector{Continuous}, Continuous}` |
32+
| [`LearnAPI.target_observation_scitype`](@ref)`(learner)` | upper bound on the scitype of each observation of the target | `Any` | `Continuous` |
33+
| [`LearnAPI.is_static`](@ref)`(learner)` | `true` if `fit` consumes no data | `false` | `true` |
3434

3535
### Derived Traits
3636

3737
The following are provided for convenience but should not be overloaded by new learners:
3838

39-
| trait | return value | example |
40-
|:-----------------------------------|:-------------------------------------------------------------------------|:--------|
41-
| `LearnAPI.name(learner)` | learner type name as string | "PCA" |
42-
| `LearnAPI.is_learner(learner)` | `true` if `learner` is LearnAPI.jl-compliant | `true` |
43-
| `LearnAPI.target(learner)` | `true` if `fit` sees a target variable; see [`LearnAPI.target`](@ref) | `false` |
44-
| `LearnAPI.weights(learner)` | `true` if `fit` supports per-observation; see [`LearnAPI.weights`](@ref) | `false` |
39+
| trait | return value | example |
40+
|:-------------------------------|:-------------------------------------------------------------------------|:--------------|
41+
| `LearnAPI.name(learner)` | learner type name as string | "PCA" |
42+
| `LearnAPI.learners(learner)` | properties with learner values | `(:atom, )` |
43+
| `LearnAPI.is_learner(learner)` | `true` if `learner` is LearnAPI.jl-compliant | `true` |
44+
| `LearnAPI.target(learner)` | `true` if `fit` sees a target variable; see [`LearnAPI.target`](@ref) | `false` |
45+
| `LearnAPI.weights(learner)` | `true` if `fit` supports per-observation weights; see [`LearnAPI.weights`](@ref) | `false` |
4546

4647
## Implementation guide
4748

49+
Only `LearnAPI.constructor` and `LearnAPI.functions` are universally compulsory.
50+
4851
A single-argument trait is declared following this pattern:
4952

5053
```julia
5154
LearnAPI.is_pure_julia(learner::MyLearnerType) = true
5255
```
5356

54-
A shorthand for single-argument traits is available:
57+
A macro [`@trait`](@ref) provides a short-cut:
5558

5659
```julia
5760
@trait MyLearnerType is_pure_julia=true
@@ -75,8 +78,8 @@ requires:
7578

7679
1. *Finiteness:* The value of a trait is the same for all `learner`s with same value of
7780
[`LearnAPI.constructor(learner)`](@ref). This typically means trait values do not
78-
depend on type parameters! If `is_composite(learner) = true`, this requirement is
79-
dropped.
81+
depend on type parameters! For composite models (`LearnAPI.learners(learner)`
82+
non-empty) this requirement is dropped.
8083

8184
2. *Low level deserializability:* It should be possible to evaluate the trait *value* when
8285
`LearnAPI` is the only imported module.
@@ -98,11 +101,15 @@ LearnAPI.pkg_name
98101
LearnAPI.pkg_license
99102
LearnAPI.doc_url
100103
LearnAPI.load_path
101-
LearnAPI.is_composite
104+
LearnAPI.nonlearners
102105
LearnAPI.human_name
103106
LearnAPI.data_interface
104107
LearnAPI.iteration_parameter
105108
LearnAPI.fit_observation_scitype
106109
LearnAPI.target_observation_scitype
107110
LearnAPI.is_static
108111
```
112+
113+
```@docs
114+
LearnAPI.learners
115+
```
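
To make the relationship between the two traits concrete, here is a minimal sketch of how a derived `learners` trait could be computed from an overloadable `nonlearners` trait. It assumes the derived trait is simply the complement of `nonlearners` among the learner's properties; this page does not show the actual definition. The module `TraitSketch` and the types `KNN` and `Bagged` are hypothetical stand-ins, not LearnAPI.jl code.

```julia
# Hypothetical stand-ins, defined in a local module so they cannot be confused
# with the real LearnAPI.jl functions.
module TraitSketch

# Fallback: every property is a non-learner, so ordinary learners declare nothing.
nonlearners(learner) = propertynames(learner)

# Derived trait (assumed definition): any property not declared a non-learner is
# taken to be learner-valued.
learners(learner) = filter(p -> p ∉ nonlearners(learner), propertynames(learner))

end

struct KNN
    K::Int
    leafsize::Int
end

struct Bagged{L}
    atom::L
    n::Int
end

# Only the composite type overloads the trait:
TraitSketch.nonlearners(::Bagged) = (:n,)

knn = KNN(5, 10)
bagged = Bagged(knn, 100)

TraitSketch.learners(knn)     # ()
TraitSketch.learners(bagged)  # (:atom,)
```

Under this assumption, a composite learner is exactly one for which `learners` is non-empty, which matches how the revised documentation now characterizes composite learners.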
