Commit 0d66698

add data intrfce rqrmnt on output of features, target, weights
1 parent 69fbfeb commit 0d66698

6 files changed: 64 additions, 32 deletions

docs/src/anatomy_of_an_implementation.md

Lines changed: 10 additions & 6 deletions
@@ -35,7 +35,7 @@ A transformer ordinarily implements `transform` instead of `predict`. For more o
 then an implementation must: (i) overload [`obs`](@ref) to articulate how
 provided data can be transformed into a form that does support
 this interface, as illustrated below under
-[Providing an advanced data interface](@ref), and which may additionally
+[Providing a separate data front end](@ref), and which may additionally
 enable certain performance benefits; or (ii) overload the trait
 [`LearnAPI.data_interface`](@ref) to specify a more relaxed data
 API.
@@ -314,7 +314,7 @@ recovered_model = deserialize(filename)
 @assert predict(recovered_model, X) == predict(model, X)
 ```

-## Providing an advanced data interface
+## Providing a separate data front end

 ```@setup anatomy2
 using LearnAPI
@@ -364,9 +364,13 @@ y = 2a - b + 3c + 0.05*rand(n)

 An implementation may optionally implement [`obs`](@ref), to expose to the user (or some
 meta-algorithm like cross-validation) the representation of input data internal to `fit`
-or `predict`, such as the matrix version `A` of `X` in the ridge example. Here we
-specifically wrap all the pre-processed data into single object, for which we introduce a
-new type:
+or `predict`, such as the matrix version `A` of `X` in the ridge example. That is, we may
+factor out of `fit` (and also `predict`) the data pre-processing step, `obs`, to expose
+its outcomes. These outcomes become alternative user inputs to `fit`. To see the use of
+`obs` in action, see [below](@ref advanced_demo).
+
+Here we specifically wrap all the pre-processed data into a single object, for which we
+introduce a new type:

 ```@example anatomy2
 struct RidgeFitObs{T,M<:AbstractMatrix{T}}
@@ -503,7 +507,7 @@ As above, we add a signature which plays no role vis-à-vis LearnAPI.jl.
 LearnAPI.fit(learner::Ridge, X, y; kwargs...) = fit(learner, (X, y); kwargs...)
 ```

-## Demonstration of an advanced `obs` workflow
+## [Demonstration of an advanced `obs` workflow](@id advanced_demo)

 We now can train and predict using internal data representations, resampled using the
 generic MLUtils.jl interface:
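
The idea of factoring the pre-processing step out of `fit`, as described in the changed text of this file, can be illustrated without LearnAPI.jl itself. What follows is only a minimal sketch under stated assumptions: `ToyRidge`, `ToyObs`, `toy_obs`, `toy_fit` and `toy_predict` are hypothetical stand-ins (not LearnAPI.jl names), and the input is assumed to be a named tuple of equal-length columns plus a target vector.

```julia
using LinearAlgebra

# Hypothetical stand-in for a learner; not a LearnAPI.jl type.
struct ToyRidge
    lambda::Float64
end

# Hypothetical container for the pre-processed training data.
struct ToyObs{T}
    A::Matrix{T}   # features, one observation per column
    y::Vector{T}   # target
end

# Pre-processing step: convert a named tuple of columns into a matrix.
function toy_obs(::ToyRidge, data)
    X, y = data
    A = permutedims(reduce(hcat, collect(values(X))))
    return ToyObs(A, collect(y))
end

# Core training acts on the pre-processed representation only.
function toy_fit(learner::ToyRidge, observations::ToyObs)
    A, y = observations.A, observations.y
    return (A * A' + learner.lambda * I) \ (A * y)   # ridge coefficients
end

# Convenience method: raw data is pre-processed first, then handled as above.
toy_fit(learner::ToyRidge, data) = toy_fit(learner, toy_obs(learner, data))

toy_predict(coefficients, A::AbstractMatrix) = A' * coefficients

# Resampling of the internal representation; `idx` is a collection of indices.
Base.getindex(o::ToyObs, idx) = ToyObs(o.A[:, idx], o.y[idx])

learner = ToyRidge(0.1)
X = (a = [1.0, 2.0, 3.0, 4.0], b = [0.5, 0.1, 0.9, 0.3])
y = [1.5, 3.9, 5.1, 7.7]

observations = toy_obs(learner, (X, y))            # pre-process once
coefficients_raw = toy_fit(learner, (X, y))        # raw data as input
coefficients_obs = toy_fit(learner, observations)  # pre-processed data as input
coefficients_sub = toy_fit(learner, observations[1:3])  # resampled, no re-conversion
```

In this sketch the pre-processed object is an alternative input to the training function, so a resampling loop can convert the table once and then index the internal representation repeatedly.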

docs/src/obs.md

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ end
 | [`obs(model, data)`](@ref) | here `data` is `predict`-consumable | not typically | returns `data` |


-A sample implementation is given in [Providing an advanced data interface](@ref).
+A sample implementation is given in [Providing a separate data front end](@ref).


 ## Reference

docs/src/reference.md

Lines changed: 7 additions & 7 deletions
@@ -16,9 +16,7 @@ ML/statistical algorithms are typically applied in conjunction with resampling o
 *observations*, as in
 [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)). In this
 document *data* will always refer to objects encapsulating an ordered sequence of
-individual observations. If a learner is trained using multiple data objects, it is
-undertood that individual objects share the same number of observations, and that
-resampling of one component implies synchronized resampling of the others.
+individual observations.

 A `DataFrame` instance, from [DataFrames.jl](https://dataframes.juliadata.org/stable/), is
 an example of data, the observations being the rows. Typically, data provided to
@@ -97,9 +95,11 @@ which can be tested with `@assert `[`LearnAPI.clone(learner)`](@ref)` == learner
 Note that if `learner` is an instance of a *mutable* struct, this requirement
 generally requires overloading `Base.==` for the struct.

-No LearnAPI.jl method is permitted to mutate a learner. In particular, one should make
-deep copies of RNG hyperparameters before using them in a new implementation of
-[`fit`](@ref).
+!!! important
+
+    No LearnAPI.jl method is permitted to mutate a learner. In particular, one should make
+    deep copies of RNG hyperparameters before using them in a new implementation of
+    [`fit`](@ref).

 #### Composite learners (wrappers)

@@ -145,7 +145,7 @@ for each.
 [`LearnAPI.functions`](@ref).

 Most learners will also implement [`predict`](@ref) and/or [`transform`](@ref). For a
-bare minimum implementation, see the implementation of `SmallLearner`
+minimal (but useless) implementation, see the implementation of `SmallLearner`
 [here](https://github.com/JuliaAI/LearnAPI.jl/blob/dev/test/traits.jl).

 ### List of methods
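
The admonition in this file about not mutating a learner, and about deep-copying RNG hyperparameters, can be sketched as follows. This is an illustration only: `ToySampler` and `toy_sample_fit` are hypothetical names, not part of LearnAPI.jl, and the choice of `MersenneTwister` is an assumption made to keep the example self-contained.

```julia
using Random

# Hypothetical learner with an RNG hyperparameter; not a LearnAPI.jl type.
struct ToySampler
    rng::MersenneTwister
    n::Int
end

# A `fit`-like function that draws random numbers without mutating the learner.
function toy_sample_fit(learner::ToySampler)
    rng = deepcopy(learner.rng)   # work on a private copy of the RNG
    return rand(rng, learner.n)
end

learner = ToySampler(MersenneTwister(123), 3)
reference_rng = deepcopy(learner.rng)   # snapshot of the RNG before any training

draws1 = toy_sample_fit(learner)
draws2 = toy_sample_fit(learner)
```

Because the learner's own RNG is never advanced, repeated calls with the same learner give the same draws in this sketch, and the learner still compares equal to a clone taken before training.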

src/fit_update.jl

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ model = fit(learner, (X, y))
 ŷ = predict(model, Xnew)
 ```

-The second signature, with `data` omitted, is provided by learners that do not
+The signature `fit(learner; verbosity=1)` (no `data`) is provided by learners that do not
 generalize to new observations (called *static algorithms*). In that case,
 `transform(model, data)` or `predict(model, ..., data)` carries out the actual algorithm
 execution, writing any byproducts of that operation to the mutable object `model` returned

src/obs.jl

Lines changed: 4 additions & 2 deletions
@@ -77,11 +77,13 @@ only of suitable tables and arrays, then `obs` and `LearnAPI.data_interface` do
 to be overloaded. However, the user will get no performance benefits by using `obs` in
 that case.

-When overloading `obs(learner, data)` to output new model-specific representations of
+If overloading `obs(learner, data)` to output new model-specific representations of
 data, it may be necessary to also overload [`LearnAPI.features(learner,
 observations)`](@ref), [`LearnAPI.target(learner, observations)`](@ref) (supervised
 learners), and/or [`LearnAPI.weights(learner, observations)`](@ref) (if weights are
-supported), for each kind output `observations` of `obs(learner, data)`.
+supported), for each kind of output `observations` of `obs(learner, data)`. Moreover, the
+outputs of these methods, applied to `observations`, must also implement the interface
+specified by [`LearnAPI.data_interface(learner)`](@ref).

 ## Sample implementation


src/target_weights_features.jl

Lines changed: 41 additions & 15 deletions
@@ -5,20 +5,28 @@ Return, for each form of `data` supported in a call of the form [`fit(learner,
 data)`](@ref), the target variable part of `data`. If `nothing` is returned, the
 `learner` does not see a target variable in training (is unsupervised).

+The returned object `y` has the same number of observations as `data`. If `data` is the
+output of an [`obs`](@ref) call, then `y` is additionally guaranteed to implement the
+data interface specified by [`LearnAPI.data_interface(learner)`](@ref).
+
 # Extended help

 ## What is a target variable?

-Examples of target variables are house prices in realestate pricing estimates, the
+Examples of target variables are house prices in real estate pricing estimates, the
 "spam"/"not spam" labels in an email spam filtering task, "outlier"/"inlier" labels in
 outlier detection, cluster labels in clustering problems, and censored survival times in
 survival analysis. For more on targets and target proxies, see the "Reference" section of
 the LearnAPI.jl documentation.

 ## New implementations

-A fallback returns `nothing`. Must be implemented if `fit` consumes data including a
-target variable.
+A fallback returns `nothing`. The method must be overloaded if `fit` consumes data
+including a target variable.
+
+If overloading [`obs`](@ref), ensure that the return value, unless `nothing`, implements
+the data interface specified by [`LearnAPI.data_interface(learner)`](@ref), in the special
+case that `data` is the output of an `obs` call.

 $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.target)"; overloaded=true))

@@ -32,10 +40,20 @@ Return, for each form of `data` supported in a call of the form [`fit(learner,
 data)`](@ref), the per-observation weights part of `data`. Where `nothing` is returned, no
 weights are part of `data`, which is to be interpreted as uniform weighting.

+The returned object `w` has the same number of observations as `data`. If `data` is the
+output of an [`obs`](@ref) call, then `w` is additionally guaranteed to implement the
+data interface specified by [`LearnAPI.data_interface(learner)`](@ref).
+
+# Extended help
+
 # New implementations

 Overloading is optional. A fallback returns `nothing`.

+If overloading [`obs`](@ref), ensure that the return value, unless `nothing`, implements
+the data interface specified by [`LearnAPI.data_interface(learner)`](@ref), in the special
+case that `data` is the output of an `obs` call.
+
 $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.weights)"; overloaded=true))

 """
@@ -53,26 +71,34 @@ implemented, as in the following sample workflow:

 ```julia
 model = fit(learner, data)
-X = features(data)
+X = LearnAPI.features(learner, data)
 ŷ = predict(model, kind_of_proxy, X) # eg, `kind_of_proxy = Point()`
 ```

-The returned object has the same number of observations as `data`. For supervised models
-(i.e., where `:(LearnAPI.target) in LearnAPI.functions(learner)`) `ŷ` above is generally
-intended to be an approximate proxy for `LearnAPI.target(learner, data)`, the training
-target.
+For supervised models (i.e., where `:(LearnAPI.target) in LearnAPI.functions(learner)`)
+`ŷ` above is generally intended to be an approximate proxy for `LearnAPI.target(learner,
+data)`, the training target.
+
+The object `X` returned by `LearnAPI.features` has the same number of observations as
+`data`. If `data` is the output of an [`obs`](@ref) call, then `X` is additionally
+guaranteed to implement the data interface specified by
+[`LearnAPI.data_interface(learner)`](@ref).

+# Extended help

 # New implementations

-That the output can be passed to `predict` and/or `transform`, and has the same number of
-observations as `data`, are the only contracts. A fallback returns `first(data)` if `data`
-is a tuple, and otherwise returns `data`.
+For density estimators, whose `fit` typically consumes *only* a target variable, you
+should overload this method to return `nothing`.
+
+It must otherwise be possible to pass the return value `X` to `predict` and/or
+`transform`, and `X` must have the same number of observations as `data`. A fallback returns
+`first(data)` if `data` is a tuple, and otherwise returns `data`.

-Overloading may be necessary if [`obs(learner, data)`](@ref) is overloaded to return
-some learner-specific representation of training `data`. For density estimators, whose
-`fit` typically consumes *only* a target variable, you should overload this method to
-return `nothing`.
+Further overloadings may be necessary to handle the case that `data` is the output of
+[`obs(learner, data)`](@ref), if `obs` is being overloaded. In this case, be sure that
+`X`, unless `nothing`, implements the data interface specified by
+[`LearnAPI.data_interface(learner)`](@ref).

 """
 features(learner, data) = _first(data)
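
The new requirement in these docstrings is that accessors applied to `obs` output must themselves return objects implementing the learner's data interface. The sketch below shows what that could look like. It is a self-contained illustration, not LearnAPI.jl code: `ToyLearner`, `ToyFitObs`, `ToyFeatures`, `toy_features`, `toy_target` and `n_observations` are all hypothetical, and the "data interface" is assumed here to mean indexing by observation plus an observation count.

```julia
# Hypothetical stand-ins; none of these names belong to LearnAPI.jl.
struct ToyLearner end

# Learner-specific output of an `obs`-like pre-processing step.
struct ToyFitObs{T}
    A::Matrix{T}   # features, one observation per column
    y::Vector{T}   # target
end

# Wrapper so that features index by observation (column) rather than by element.
struct ToyFeatures{T}
    A::Matrix{T}
end
Base.getindex(X::ToyFeatures, idx) = ToyFeatures(X.A[:, idx])
n_observations(X::ToyFeatures) = size(X.A, 2)
n_observations(y::AbstractVector) = length(y)

# Fallback behaviour as described above: first element of a tuple, else the data itself.
toy_features(::ToyLearner, data) = data
toy_features(::ToyLearner, data::Tuple) = first(data)
# Assumption for this sketch: the learner is supervised and sees `(X, y)` tuples.
toy_target(::ToyLearner, data::Tuple) = last(data)

# Extra methods for the learner-specific output of the `obs`-like step.
toy_features(::ToyLearner, observations::ToyFitObs) = ToyFeatures(observations.A)
toy_target(::ToyLearner, observations::ToyFitObs) = observations.y

learner = ToyLearner()
observations = ToyFitObs([1.0 2.0 3.0; 4.0 5.0 6.0], [0.1, 0.2, 0.3])

X = toy_features(learner, observations)   # supports indexing and counting
y = toy_target(learner, observations)     # a plain vector already does
```

Returning the raw matrix `observations.A` from `toy_features` would break the assumed interface, since indexing it with a range would select elements rather than observations; the wrapper is one possible way to avoid that.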
