Skip to content

Commit 72009e2

Browse files
committed
tweaks and corrections
1 parent db2f287 commit 72009e2

File tree

4 files changed

+16
-21
lines changed

4 files changed

+16
-21
lines changed

docs/src/anatomy_of_an_implementation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ specified by the trait* [`LearnAPI.data_interface(algorithm)`](@ref). Assuming t
411411
```@example anatomy2
412412
Base.getindex(data::RidgeFitObs, I) =
413413
RidgeFitObs(data.A[:,I], data.names, y[I])
414-
Base.length(data::RidgeFitObs, I) = length(data.y)
414+
Base.length(data::RidgeFitObs) = length(data.y)
415415
```
416416

417417
We can do something similar for `predict`, but there's no need for a new type in this

docs/src/target_weights_features.md

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,11 @@ training_loss = sum(ŷ .!= y)
2828

2929
# Implementation guide
3030

31-
The fallback returns `first(data)`, assuming `data` is a tuple, and `data` otherwise.
32-
33-
| method | fallback | compulsory? |
34-
|:----------------------------|:-----------------:|------------------------|
35-
| [`LearnAPI.target`](@ref) | returns `nothing` | no |
36-
| [`LearnAPI.weights`](@ref) | returns `nothing` | no |
37-
| [`LearnAPI.features`](@ref) | see docstring | only if fallback fails |
31+
| method | fallback | compulsory? |
32+
|:----------------------------|:-----------------:|--------------------------|
33+
| [`LearnAPI.target`](@ref) | returns `nothing` | no |
34+
| [`LearnAPI.weights`](@ref) | returns `nothing` | no |
35+
| [`LearnAPI.features`](@ref) | see docstring | if fallback insufficient |
3836

3937

4038
# Reference

src/target_weights_features.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ weights(::Any, data) = nothing
3838
3939
Return, for each form of `data` supported in a call of the form [`fit(algorithm,
4040
data)`](@ref), the "features" part of `data` (as opposed to the target
41-
variable, for example).
41+
variable, for example).
4242
4343
The returned object `X` may always be passed to `predict` or `transform`, where
4444
implemented, as in the following sample workflow:
@@ -49,28 +49,25 @@ X = features(data)
4949
ŷ = predict(algorithm, kind_of_proxy, X) # eg, `kind_of_proxy = Point()`
5050
```
5151
52-
The return value has the same number of observations as `data` does. For supervised models
52+
The returned object has the same number of observations as `data`. For supervised models
5353
(i.e., where `:(LearnAPI.target) in LearnAPI.functions(algorithm)`) `ŷ` above is generally
5454
intended to be an approximate proxy for `LearnAPI.target(algorithm, data)`, the training
5555
target.
5656
5757
5858
# New implementations
5959
60-
The only contract `features` must satisfy is the one about passability of the output to
61-
`predict` or `transform`, for each supported input `data`. The following fallbacks
62-
typically make overloading `LearnAPI.features` unnecessary:
63-
64-
```julia
65-
LearnAPI.features(algorithm, data) = data
66-
LearnAPI.features(algorithm, data::Tuple) = first(data)
67-
```
60+
That the output can be passed to `predict` and/or `transform`, and has the same number of
61+
observations as `data`, are the only contracts. A fallback returns `first(data)` if `data`
62+
is a tuple, and otherwise returns `data`.
6863
6964
Overloading may be necessary if [`obs(algorithm, data)`](@ref) is overloaded to return
7065
some algorithm-specific representation of training `data`. For density estimators, whose
7166
`fit` typically consumes *only* a target variable, you should overload this method to
7267
return `nothing`.
7368
7469
"""
75-
features(algorithm, data) = data
76-
features(algorithm, data::Tuple) = first(data)
70+
features(algorithm, data) = _first(data)
71+
_first(data) = data
72+
_first(data::Tuple) = first(data)
73+
# note the factoring above guards agains method ambiguities

test/integration/regression.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ LearnAPI.algorithm(model::RidgeFitted) = model.algorithm
3939

4040
Base.getindex(data::RidgeFitObs, I) =
4141
RidgeFitObs(data.A[:,I], data.names, data.y[I])
42-
Base.length(data::RidgeFitObs, I) = length(data.y)
42+
Base.length(data::RidgeFitObs) = length(data.y)
4343

4444
# observations for consumption by `fit`:
4545
function LearnAPI.obs(::Ridge, data)

0 commit comments

Comments
 (0)