Skip to content

Commit 1e504fe

Browse files
committed
replace minimize -> LearnAPI.strip
oops oops tweak
1 parent 8360ad4 commit 1e504fe

17 files changed

+100
-119
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Configure a learning algorithm, and inspect available functionality:
1818
```julia
1919
julia> algorithm = Ridge(lambda=0.1)
2020
julia> LearnAPI.functions(algorithm)
21-
(:(LearnAPI.fit), :(LearnAPI.algorithm), :(LearnAPI.minimize), :(LearnAPI.obs),
21+
(:(LearnAPI.fit), :(LearnAPI.algorithm), :(LearnAPI.strip), :(LearnAPI.obs),
2222
:(LearnAPI.features), :(LearnAPI.target), :(LearnAPI.predict), :(LearnAPI.coefficients))
2323
```
2424

docs/make.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ makedocs(
1818
"fit/update" => "fit_update.md",
1919
"predict/transform" => "predict_transform.md",
2020
"Kinds of Target Proxy" => "kinds_of_target_proxy.md",
21-
"minimize" => "minimize.md",
2221
"target/weights/features" => "target_weights_features.md",
2322
"obs" => "obs.md",
2423
"Accessor Functions" => "accessor_functions.md",

docs/src/accessor_functions.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# [Accessor Functions](@id accessor_functions)
22

33
The sole argument of an accessor function is the output, `model`, of
4-
[`fit`](@ref). Algorithms are free to implement any number of these, or none of them.
4+
[`fit`](@ref). Algorithms are free to implement any number of these, or none of them. Only
5+
`LearnAPI.strip` has a fallback, namely the identity.
56

67
- [`LearnAPI.algorithm(model)`](@ref)
78
- [`LearnAPI.extras(model)`](@ref)
9+
- [`LearnAPI.strip(model)`](@ref)
810
- [`LearnAPI.coefficients(model)`](@ref)
911
- [`LearnAPI.intercept(model)`](@ref)
1012
- [`LearnAPI.tree(model)`](@ref)
@@ -31,6 +33,7 @@ optional, any implemented accessor functions must be added to the list returned
3133
```@docs
3234
LearnAPI.algorithm
3335
LearnAPI.extras
36+
LearnAPI.strip
3437
LearnAPI.coefficients
3538
LearnAPI.intercept
3639
LearnAPI.tree

docs/src/anatomy_of_an_implementation.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,15 @@ nothing #hide
183183

184184
## Tearing a model down for serialization
185185

186-
The `minimize` method falls back to the identity. Here, for the sake of illustration, we
186+
The `LearnAPI.strip` method falls back to the identity. Here, for the sake of illustration, we
187187
overload it to dump the named version of the coefficients:
188188

189189
```@example anatomy
190-
LearnAPI.minimize(model::RidgeFitted) =
190+
LearnAPI.strip(model::RidgeFitted) =
191191
RidgeFitted(model.algorithm, model.coefficients, nothing)
192192
```
193193

194-
Crucially, we can still use `LearnAPI.minimize(model)` in place of `model` to make new
194+
Crucially, we can still use `LearnAPI.strip(model)` in place of `model` to make new
195195
predictions.
196196

197197

@@ -220,7 +220,7 @@ A macro provides a shortcut, convenient when multiple traits are to be defined:
220220
functions = (
221221
:(LearnAPI.fit),
222222
:(LearnAPI.algorithm),
223-
:(LearnAPI.minimize),
223+
:(LearnAPI.strip),
224224
:(LearnAPI.obs),
225225
:(LearnAPI.features),
226226
:(LearnAPI.target),
@@ -285,7 +285,7 @@ Serialization/deserialization:
285285

286286
```@example anatomy
287287
using Serialization
288-
small_model = minimize(model)
288+
small_model = LearnAPI.strip(model)
289289
filename = tempname()
290290
serialize(filename, small_model)
291291
```
@@ -316,7 +316,7 @@ end
316316
317317
LearnAPI.algorithm(model::RidgeFitted) = model.algorithm
318318
LearnAPI.coefficients(model::RidgeFitted) = model.named_coefficients
319-
LearnAPI.minimize(model::RidgeFitted) =
319+
LearnAPI.strip(model::RidgeFitted) =
320320
RidgeFitted(model.algorithm, model.coefficients, nothing)
321321
322322
@trait(
@@ -327,7 +327,7 @@ LearnAPI.minimize(model::RidgeFitted) =
327327
functions = (
328328
:(LearnAPI.fit),
329329
:(LearnAPI.algorithm),
330-
:(LearnAPI.minimize),
330+
:(LearnAPI.strip),
331331
:(LearnAPI.obs),
332332
:(LearnAPI.features),
333333
:(LearnAPI.target),

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ predict(model, Distribution(), Xnew)
6060
LearnAPI.feature_importances(model)
6161

6262
# Slim down and otherwise prepare model for serialization:
63-
small_model = minimize(model)
63+
small_model = LearnAPI.strip(model)
6464
serialize("my_random_forest.jls", small_model)
6565

6666
# Recover saved model and algorithm configuration:

docs/src/minimize.md

Lines changed: 0 additions & 34 deletions
This file was deleted.

docs/src/reference.md

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ for each.
138138

139139
!!! note "Compulsory methods"
140140

141-
All new algorithm types must implement [`fit`](@ref),
142-
[`LearnAPI.algorithm`](@ref algorithm_minimize), [`LearnAPI.constructor`](@ref) and
143-
[`LearnAPI.functions`](@ref).
141+
All new algorithm types must implement [`fit`](@ref),
142+
[`LearnAPI.algorithm`](@ref), [`LearnAPI.constructor`](@ref) and
143+
[`LearnAPI.functions`](@ref).
144144

145145
Most algorithms will also implement [`predict`](@ref) and/or [`transform`](@ref). For a
146146
bare minimum implementation, see the implementation of `SmallAlgorithm`
@@ -152,10 +152,10 @@ bare minimum implementation, see the implementation of `SmallAlgorithm`
152152
for non-generalizing algorithms (see [here](@ref static_algorithms) and [Static
153153
Algorithms](@ref)), for wrapping `algorithm` in a mutable struct that can be mutated by
154154
`predict`/`transform` to record byproducts of those operations.
155-
155+
156156
- [`update`](@ref fit): for updating learning outcomes after hyperparameter changes, such
157157
as increasing an iteration parameter.
158-
158+
159159
- [`update_observations`](@ref fit), [`update_features`](@ref fit): update learning
160160
outcomes by presenting additional training data.
161161

@@ -168,9 +168,6 @@ bare minimum implementation, see the implementation of `SmallAlgorithm`
168168
- [`inverse_transform`](@ref operations): for inverting the output of
169169
`transform` ("inverting" broadly understood)
170170

171-
- [`minimize`](@ref algorithm_minimize): for stripping the `model` output by `fit` of
172-
inessential content, for purposes of serialization.
173-
174171
- [`LearnAPI.target`](@ref input), [`LearnAPI.weights`](@ref input),
175172
[`LearnAPI.features`](@ref): for extracting relevant parts of training data, where
176173
defined.
@@ -181,8 +178,10 @@ bare minimum implementation, see the implementation of `SmallAlgorithm`
181178
[`LearnAPI.data_interface(algorithm)`](@ref).
182179

183180
- [Accessor functions](@ref accessor_functions): these include functions like
184-
`feature_importances` and `training_losses`, for extracting, from training outcomes,
185-
information common to many algorithms.
181+
`LearnAPI.feature_importances` and `LearnAPI.training_losses`, for extracting, from
182+
training outcomes, information common to many algorithms. This includes
183+
[`LearnAPI.strip(model)`](@ref) for replacing a learning outcome `model` with a
184+
serializable version that can still `predict` or `transform`.
186185

187186
- [Algorithm traits](@ref traits): methods that promise specific algorithm behavior or
188187
record general information about the algorithm. Only [`LearnAPI.constructor`](@ref) and

docs/src/traits.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ In the examples column of the table below, `Continuous` is a name owned the pack
1616
| trait | return value | fallback value | example |
1717
|:-----------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------|:-----------------------------------------------------------|
1818
| [`LearnAPI.constructor`](@ref)`(algorithm)` | constructor for generating new or modified versions of `algorithm` | (no fallback) | `RidgeRegressor` |
19-
| [`LearnAPI.functions`](@ref)`(algorithm)` | functions you can apply to `algorithm` or associated model (traits excluded) | `()` | `(:fit, :predict, :minimize, :(LearnAPI.algorithm), :obs)` |
19+
| [`LearnAPI.functions`](@ref)`(algorithm)` | functions you can apply to `algorithm` or associated model (traits excluded) | `()` | `(:fit, :predict, :LearnAPI.strip, :(LearnAPI.algorithm), :obs)` |
2020
| [`LearnAPI.kinds_of_proxy`](@ref)`(algorithm)` | instances `kind` of `KindOfProxy` for which an implementation of `LearnAPI.predict(algorithm, kind, ...)` is guaranteed. | `()` | `(Distribution(), Interval())` |
2121
| [`LearnAPI.tags`](@ref)`(algorithm)` | lists one or more suggestive algorithm tags from `LearnAPI.tags()` | `()` | (:regression, :probabilistic) |
2222
| [`LearnAPI.is_pure_julia`](@ref)`(algorithm)` | `true` if implementation is 100% Julia code | `false` | `true` |

src/LearnAPI.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ include("tools.jl")
66
include("types.jl")
77
include("predict_transform.jl")
88
include("fit_update.jl")
9-
include("minimize.jl")
109
include("target_weights_features.jl")
1110
include("obs.jl")
1211
include("accessor_functions.jl")
@@ -15,7 +14,7 @@ include("clone.jl")
1514

1615
export @trait
1716
export fit, update, update_observations, update_features
18-
export predict, transform, inverse_transform, minimize, obs
17+
export predict, transform, inverse_transform, obs
1918

2019
for name in Symbol.(CONCRETE_TARGET_PROXY_TYPES_SYMBOLS)
2120
@eval export $name

src/accessor_functions.jl

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ const DOC_STATIC =
1616

1717
"""
1818
LearnAPI.algorithm(model)
19-
LearnAPI.algorithm(minimized_model)
19+
LearnAPI.algorithm(LearnAPI.stripd_model)
2020
21-
Recover the algorithm used to train `model` or the output of [`minimize(model)`](@ref).
21+
Recover the algorithm used to train `model` or the output of [`LearnAPI.strip(model)`](@ref).
2222
2323
In other words, if `model = fit(algorithm, data...)`, for some `algorithm` and `data`,
2424
then
2525
2626
```julia
27-
LearnAPI.algorithm(model) == algorithm == LearnAPI.algorithm(minimize(model))
27+
LearnAPI.algorithm(model) == algorithm == LearnAPI.algorithm(LearnAPI.strip(model))
2828
```
2929
is `true`.
3030
@@ -36,6 +36,61 @@ only contract. $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.algorithm)"))
3636
"""
3737
function algorithm end
3838

39+
"""
40+
LearnAPI.strip(model; options...)
41+
42+
Return a version of `model` that will generally have a smaller memory allocation than
43+
`model`, suitable for serialization. Here `model` is any object returned by
44+
[`fit`](@ref). Accessor functions that can be called on `model` may not work on
45+
`LearnAPI.strip(model)`, but [`predict`](@ref), [`transform`](@ref) and
46+
[`inverse_transform`](@ref) will work, if implemented. Check
47+
`LearnAPI.functions(LearnAPI.algorithm(model))` to view see what the original `model`
48+
implements.
49+
50+
Specific algorithms may provide keyword `options` to control how much of the original
51+
functionality is preserved by `LearnAPI.strip`.
52+
53+
# Typical workflow
54+
55+
```julia
56+
model = fit(algorithm, (X, y)) # or `fit(algorithm, X, y)`
57+
ŷ = predict(model, Point(), Xnew)
58+
59+
small_model = LearnAPI.strip(model)
60+
serialize("my_model.jls", small_model)
61+
62+
recovered_model = deserialize("my_random_forest.jls")
63+
@assert predict(recovered_model, Point(), Xnew) == ŷ
64+
```
65+
66+
# Extended help
67+
68+
# New implementations
69+
70+
Overloading `LearnAPI.strip` for new algorithms is optional. The fallback is the
71+
identity.
72+
73+
New implementations must enforce the following identities, whenever the right-hand side is
74+
defined:
75+
76+
```julia
77+
predict(LearnAPI.strip(model; options...), args...; kwargs...) ==
78+
predict(model, args...; kwargs...)
79+
transform(LearnAPI.strip(model; options...), args...; kwargs...) ==
80+
transform(model, args...; kwargs...)
81+
inverse_transform(LearnAPI.strip(model; options), args...; kwargs...) ==
82+
inverse_transform(model, args...; kwargs...)
83+
```
84+
85+
Additionally:
86+
87+
```julia
88+
LearnAPI.strip(LearnAPI.strip(model)) == LearnAPI.strip(model)
89+
```
90+
91+
"""
92+
LearnAPI.strip(model) = model
93+
3994
"""
4095
LearnAPI.feature_importances(model)
4196

0 commit comments

Comments
 (0)