Skip to content

Commit 6da8531

Browse files
committed
dump predict_or_transform_mutates in favour of is_static
1 parent 82a9e68 commit 6da8531

File tree

8 files changed

+78
-47
lines changed

8 files changed

+78
-47
lines changed

docs/src/fit_update.md

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Data slurping forms are similarly provided for updating methods.
2626

2727
## Typical workflows
2828

29+
### Supervised models
30+
2931
Supposing `Algorithm` is some supervised classifier type, with an iteration parameter `n`:
3032

3133
```julia
@@ -43,15 +45,32 @@ model = update(model; n=150)
4345
predict(model, Distribution(), X)
4446
```
4547

46-
### A static algorithm (no "learning")
48+
### Tranformers
49+
50+
A dimension-reducing transformer, `algorithm` might be used in this way:
51+
52+
```julia
53+
model = fit(algorithm, X)
54+
transform(model, X) # or `transform(model, Xnew)`
55+
```
56+
57+
or, if implemented, using a single call:
58+
59+
```julia
60+
transform(algorithm, X) # `fit` implied
61+
```
62+
63+
### Static algorithms (no "learning")
64+
65+
Suppose `algorithm` is some clustering algorithm that cannot be generalized to new data
66+
(e.g. DBSCAN):
4767

4868
```julia
49-
# Apply some clustering algorithm which cannot be generalized to new data:
5069
model = fit(algorithm) # no training data
51-
labels = predict(model, LabelAmbiguous(), X) # may mutate `model`
70+
labels = predict(model, X) # may mutate `model`
5271

5372
# Or, in one line:
54-
labels = predict(algorithm, LabelAmbiguous(), X)
73+
labels = predict(algorithm, X)
5574

5675
# But two-line version exposes byproducts of the clustering algorithm (e.g., outliers):
5776
LearnAPI.extras(model)

docs/src/traits.md

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,24 @@ training). They may also record more mundane information, such as a package lice
1313
In the examples column of the table below, `Continuous` is a name owned the package
1414
[ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.jl/).
1515

16-
| trait | return value | fallback value | example |
17-
|:-------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------|:-----------------------------------------------------------|
18-
| [`LearnAPI.constructor`](@ref)`(algorithm)` | constructor for generating new or modified versions of `algorithm` | (no fallback) | `RidgeRegressor` |
19-
| [`LearnAPI.functions`](@ref)`(algorithm)` | functions you can apply to `algorithm` or associated model (traits excluded) | `()` | `(:fit, :predict, :minimize, :(LearnAPI.algorithm), :obs)` |
20-
| [`LearnAPI.kinds_of_proxy`](@ref)`(algorithm)` | instances `kind` of `KindOfProxy` for which an implementation of `LearnAPI.predict(algorithm, kind, ...)` is guaranteed. | `()` | `(Distribution(), Interval())` |
21-
| [`LearnAPI.tags`](@ref)`(algorithm)` | lists one or more suggestive algorithm tags from `LearnAPI.tags()` | `()` | (:regression, :probabilistic) |
22-
| [`LearnAPI.is_pure_julia`](@ref)`(algorithm)` | `true` if implementation is 100% Julia code | `false` | `true` |
23-
| [`LearnAPI.pkg_name`](@ref)`(algorithm)` | name of package providing core code (may be different from package providing LearnAPI.jl implementation) | `"unknown"` | `"DecisionTree"` |
24-
| [`LearnAPI.pkg_license`](@ref)`(algorithm)` | name of license of package providing core code | `"unknown"` | `"MIT"` |
25-
| [`LearnAPI.doc_url`](@ref)`(algorithm)` | url providing documentation of the core code | `"unknown"` | `"https://en.wikipedia.org/wiki/Decision_tree_learning"` |
26-
| [`LearnAPI.load_path`](@ref)`(algorithm)` | string locating name returned by `LearnAPI.constructor(algorithm)`, beginning with a package name | "unknown"` | `FastTrees.LearnAPI.DecisionTreeClassifier` |
27-
| [`LearnAPI.is_composite`](@ref)`(algorithm)` | `true` if one or more properties of `algorithm` may be an algorithm | `false` | `true` |
28-
| [`LearnAPI.human_name`](@ref)`(algorithm)` | human name for the algorithm; should be a noun | type name with spaces | "elastic net regressor" |
29-
| [`LearnAPI.iteration_parameter`](@ref)`(algorithm)` | symbolic name of an iteration parameter | `nothing` | :epochs |
30-
| [`LearnAPI.data_interface`](@ref)`(algorithm)` | Interface implemented by objects returned by [`obs`](@ref) | `Base.HasLength()` (supports `MLUtils.getobs/numobs`) | `Base.SizeUnknown()` (supports `iterate`) |
31-
| [`LearnAPI.fit_observation_scitype`](@ref)`(algorithm)` | upper bound on `scitype(observation)` for `observation` in `data` ensuring `fit(algorithm, data)` works | `Union{}` | `Tuple{AbstractVector{Continuous}, Continuous}` |
32-
| [`LearnAPI.target_observation_scitype`](@ref)`(algorithm)` | upper bound on the scitype of each observation of the targget | `Any` | `Continuous` |
33-
| [`LearnAPI.predict_or_transform_mutates`](@ref)`(algorithm)` | `true` if `predict` or `transform` mutates first argument | `false` | `true` |
16+
| trait | return value | fallback value | example |
17+
|:-----------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------|:-----------------------------------------------------------|
18+
| [`LearnAPI.constructor`](@ref)`(algorithm)` | constructor for generating new or modified versions of `algorithm` | (no fallback) | `RidgeRegressor` |
19+
| [`LearnAPI.functions`](@ref)`(algorithm)` | functions you can apply to `algorithm` or associated model (traits excluded) | `()` | `(:fit, :predict, :minimize, :(LearnAPI.algorithm), :obs)` |
20+
| [`LearnAPI.kinds_of_proxy`](@ref)`(algorithm)` | instances `kind` of `KindOfProxy` for which an implementation of `LearnAPI.predict(algorithm, kind, ...)` is guaranteed. | `()` | `(Distribution(), Interval())` |
21+
| [`LearnAPI.tags`](@ref)`(algorithm)` | lists one or more suggestive algorithm tags from `LearnAPI.tags()` | `()` | (:regression, :probabilistic) |
22+
| [`LearnAPI.is_pure_julia`](@ref)`(algorithm)` | `true` if implementation is 100% Julia code | `false` | `true` |
23+
| [`LearnAPI.pkg_name`](@ref)`(algorithm)` | name of package providing core code (may be different from package providing LearnAPI.jl implementation) | `"unknown"` | `"DecisionTree"` |
24+
| [`LearnAPI.pkg_license`](@ref)`(algorithm)` | name of license of package providing core code | `"unknown"` | `"MIT"` |
25+
| [`LearnAPI.doc_url`](@ref)`(algorithm)` | url providing documentation of the core code | `"unknown"` | `"https://en.wikipedia.org/wiki/Decision_tree_learning"` |
26+
| [`LearnAPI.load_path`](@ref)`(algorithm)` | string locating name returned by `LearnAPI.constructor(algorithm)`, beginning with a package name | "unknown"` | `FastTrees.LearnAPI.DecisionTreeClassifier` |
27+
| [`LearnAPI.is_composite`](@ref)`(algorithm)` | `true` if one or more properties of `algorithm` may be an algorithm | `false` | `true` |
28+
| [`LearnAPI.human_name`](@ref)`(algorithm)` | human name for the algorithm; should be a noun | type name with spaces | "elastic net regressor" |
29+
| [`LearnAPI.iteration_parameter`](@ref)`(algorithm)` | symbolic name of an iteration parameter | `nothing` | :epochs |
30+
| [`LearnAPI.data_interface`](@ref)`(algorithm)` | Interface implemented by objects returned by [`obs`](@ref) | `Base.HasLength()` (supports `MLUtils.getobs/numobs`) | `Base.SizeUnknown()` (supports `iterate`) |
31+
| [`LearnAPI.fit_observation_scitype`](@ref)`(algorithm)` | upper bound on `scitype(observation)` for `observation` in `data` ensuring `fit(algorithm, data)` works | `Union{}` | `Tuple{AbstractVector{Continuous}, Continuous}` |
32+
| [`LearnAPI.target_observation_scitype`](@ref)`(algorithm)` | upper bound on the scitype of each observation of the targget | `Any` | `Continuous` |
33+
| [`LearnAPI.is_static`](@ref)`(algorithm)` | `true` if `fit` consumes no data | `false` | `true` |
3434

3535
### Derived Traits
3636

@@ -104,5 +104,5 @@ LearnAPI.data_interface
104104
LearnAPI.iteration_parameter
105105
LearnAPI.fit_observation_scitype
106106
LearnAPI.target_observation_scitype
107-
LearnAPI.predict_or_transform_mutates
107+
LearnAPI.is_static
108108
```

src/clone.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Return a shallow copy of `algorithm` with the specified hyperparameter replaceme
77
clone(algorithm; epochs=100, learning_rate=0.01)
88
```
99
10-
It is guaranted that `LearnAPI.clone(algorithm) == algorithm`.
10+
It is guaranteed that `LearnAPI.clone(algorithm) == algorithm`.
1111
1212
"""
1313
function clone(algorithm; replacements...)

src/fit_update.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ returning an object, `model`, on which other methods, such as [`predict`](@ref)
1010
list of methods that can be applied to either `algorithm` or `model`.
1111
1212
The second signature is provided by algorithms that do not generalize to new observations
13-
("static" algorithms). In that case, `transform(model, data)` or `predict(model, ...,
14-
data)` carries out the actual algorithm execution, writing any byproducts of that
13+
(called *static algorithms*). In that case, `transform(model, data)` or `predict(model,
14+
..., data)` carries out the actual algorithm execution, writing any byproducts of that
1515
operation to the mutable object `model` returned by `fit`.
1616
1717
Whenever `fit` expects a tuple form of argument, `data = (X1, ..., Xn)`, then the
@@ -33,14 +33,16 @@ See also [`predict`](@ref), [`transform`](@ref), [`inverse_transform`](@ref),
3333
3434
# New implementations
3535
36-
Implementation is compulsory. The signature must include `verbosity`. A fallback for the
37-
first signature calls the second, ignoring `data`:
36+
Implementation is compulsory. The signature must include `verbosity`. Fallbacks provide
37+
the data slurping versions. A fallback for the first signature calls the second, ignoring
38+
`data`:
3839
3940
```julia
4041
fit(algorithm, data; kwargs...) = fit(algorithm; kwargs...)
4142
```
4243
43-
Fallbacks also provide the data slurping versions.
44+
If only the `fit(algorithm)` signature is expliclty implemented, then the trait
45+
[`LearnAPI.is_static`](@ref) must be overloaded to return `true`.
4446
4547
$(DOC_DATA_INTERFACE(:fit))
4648

src/predict_transform.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@ const DOC_OPERATIONS_LIST_FUNCTION = join(map(op -> "`LearnAPI.$op`", OPERATIONS
1111
DOC_MUTATION(op) =
1212
"""
1313
14-
If [`LearnAPI.predict_or_transform_mutates(algorithm)`](@ref) is overloaded to return
15-
`true`, then `$op` may mutate it's first argument, but not in a way that alters the
16-
result of a subsequent call to `predict`, `transform` or
17-
`inverse_transform`. This is necessary for some non-generalizing algorithms but is
18-
otherwise discouraged. See more at [`fit`](@ref).
14+
If [`LearnAPI.is_static(algorithm)`](@ref) is `true`, then `$op` may mutate it's first
15+
argument, but not in a way that alters the result of a subsequent call to `predict`,
16+
`transform` or `inverse_transform`. See more at [`fit`](@ref).
1917
2018
"""
2119

@@ -86,7 +84,7 @@ If `predict` supports data in the form of a tuple `data = (X1, ..., Xn)`, then a
8684
signature is also provided, as in `predict(model, X1, ..., Xn)`.
8785
8886
Note `predict ` does not mutate any argument, except in the special case
89-
`LearnAPI.predict_or_transform_mutates(algorithm) = true`.
87+
`LearnAPI.is_static(algorithm) == true`.
9088
9189
# New implementations
9290
@@ -150,7 +148,7 @@ W = transform(algorithm, X)
150148
```
151149
152150
Note `transform` does not mutate any argument, except in the special case
153-
`LearnAPI.predict_or_transform_mutates(algorithm) = true`.
151+
`LearnAPI.is_static(algorithm) == true`.
154152
155153
See also [`fit`](@ref), [`predict`](@ref),
156154
[`inverse_transform`](@ref).

src/target_weights_features.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,4 @@ return `nothing`.
7070
features(algorithm, data) = _first(data)
7171
_first(data) = data
7272
_first(data::Tuple) = first(data)
73-
# note the factoring above guards agains method ambiguities
73+
# note the factoring above guards against method ambiguities

src/traits.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,19 +346,30 @@ tables, and tuples of these. See the doc-string for details.
346346
data_interface(::Any) = LearnAPI.RandomAccess()
347347

348348
"""
349-
LearnAPI.predict_or_transform_mutates(algorithm)
349+
LearnAPI.is_static(algorithm)
350350
351-
Returns `true` if [`predict`](@ref) or [`transform`](@ref) possibly mutate their first
352-
argument, `model`, when `LearnAPI.algorithm(model) == algorithm`. If `false`, no arguments
353-
are ever mutated.
351+
Returns `true` if [`fit`](@ref) is called with no data arguments, as in
352+
`fit(algorithm)`. That is, `algorithm` does not generalize to new data, and data is only
353+
provided at the `predict` or `transform` step.
354+
355+
For example, some clustering algorithms are applied with this workflow, to label points
356+
observations in `X`:
357+
358+
```julia
359+
model = fit(algorithm) # no training data
360+
labels = predict(model, X) # may mutate `model`!
361+
362+
# extract some byproducts of the clustering algorithm (e.g., outliers):
363+
LearnAPI.extras(model)
364+
```
354365
355366
# New implementations
356367
357368
This trait, falling back to `false`, may only be overloaded when `fit` has no data
358-
arguments (`algorithm` does not generalize to new data). See more at [`fit`](@ref).
369+
arguments. See more at [`fit`](@ref).
359370
360371
"""
361-
predict_or_transform_mutates(::Any) = false
372+
is_static(::Any) = false
362373

363374
"""
364375
LearnAPI.iteration_parameter(algorithm)

test/integration/static_algorithms.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ function LearnAPI.transform(algorithm::Selector, X)
3636
transform(model, X)
3737
end
3838

39+
# note the necessity of overloading `is_static` (`fit` consumes no data):
3940
@trait(
4041
Selector,
4142
constructor = Selector,
4243
tags = ("feature engineering",),
44+
is_static = true,
4345
functions = (
4446
:(LearnAPI.fit),
4547
:(LearnAPI.algorithm),
@@ -63,9 +65,7 @@ end
6365
# # FEATURE SELECTOR THAT REPORTS BYPRODUCTS OF SELECTION PROCESS
6466

6567
# This a variation of `Selector` above that stores the names of rejected features in the
66-
# model object, for inspection by an accessor function called `rejected`. Since
67-
# `transform(model, X)` mutates `model` in this case, we must overload the
68-
# `predict_or_transform_mutates` trait.
68+
# output of `fit`, for inspection by an accessor function called `rejected`.
6969

7070
struct Selector2
7171
names::Vector{Symbol}
@@ -101,10 +101,11 @@ function LearnAPI.transform(algorithm::Selector2, X)
101101
transform(model, X)
102102
end
103103

104+
# note the necessity of overloading `is_static` (`fit` consumes no data):
104105
@trait(
105106
Selector2,
106107
constructor = Selector2,
107-
predict_or_transform_mutates = true,
108+
is_static = true,
108109
tags = ("feature engineering",),
109110
functions = (
110111
:(LearnAPI.fit),

0 commit comments

Comments
 (0)