Commit 2b11e6b

clarify need for obs to be involutive
1 parent 6f436ef commit 2b11e6b

2 files changed

+37
-13
lines changed

docs/src/anatomy_of_an_implementation.md

Lines changed: 17 additions & 6 deletions
@@ -420,10 +420,21 @@ LearnAPI.fit(learner::Ridge, data; kwargs...) =
420420

421421
### The `obs` contract
422422

423-
Providing `fit` signatures matching the output of `obs`, is the first part of the `obs`
424-
contract. The second part is this: *The output of `obs` must implement the interface
425-
specified by the trait* [`LearnAPI.data_interface(learner)`](@ref). Assuming this is
426-
[`LearnAPI.RandomAccess()`](@ref) (the default) it usually suffices to overload
423+
Providing `fit` signatures matching the output of [`obs`](@ref) is the first part of the
424+
`obs` contract. Since `obs(learner, data)` should evidently support all `data` that
425+
`fit(learner, data)` supports, we must be able to apply `obs(learner, _)` to its own
426+
output (`observations` below). This leads to the additional "no-op" declaration
427+
428+
```@example anatomy2
429+
LearnAPI.obs(::Ridge, observations::RidgeFitObs) = observations
430+
```
431+
432+
In other words, we ensure that `obs(learner, _)` is
433+
[involutive](https://en.wikipedia.org/wiki/Involution_(mathematics)).
434+
435+
The second part of the `obs` contract is this: *The output of `obs` must implement the
436+
interface specified by the trait* [`LearnAPI.data_interface(learner)`](@ref). Assuming
437+
this is [`LearnAPI.RandomAccess()`](@ref) (the default) it usually suffices to overload
427438
`Base.getindex` and `Base.length`:
428439

429440
```@example anatomy2
@@ -432,11 +443,11 @@ Base.getindex(data::RidgeFitObs, I) =
432443
Base.length(data::RidgeFitObs) = length(data.y)
433444
```
434445

435-
We can do something similar for `predict`, but there's no need for a new type in this
436-
case:
446+
We do something similar for `predict`, but there's no need for a new type in this case:
437447

438448
```@example anatomy2
439449
LearnAPI.obs(::RidgeFitted, Xnew) = Tables.matrix(Xnew)'
450+
LearnAPI.obs(::RidgeFitted, observations::AbstractArray) = observations # involutivity
440451
441452
LearnAPI.predict(model::RidgeFitted, ::Point, observations::AbstractMatrix) =
442453
observations'*model.coefficients
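
The fit-side changes in this file can be tried out without LearnAPI.jl. The following is a minimal sketch, not the package's implementation: `ToyFitObs` and `toy_obs` are hypothetical stand-ins for `RidgeFitObs` and `LearnAPI.obs(learner, _)`, the learner argument is dropped for brevity, and the table is assumed to be a named tuple of equal-length columns. It shows the "no-op" method next to the `Base.getindex` and `Base.length` overloads.

```julia
# Hypothetical stand-in for the observation container in the diff.
struct ToyFitObs{T}
    A::Matrix{T}   # features, one observation per column
    y::Vector{T}   # target
end

# Stand-in for `obs(learner, data)`, where `data = (X, y)` and `X` is a
# named tuple of columns (an assumption made for this sketch).
toy_obs(data::Tuple) = ToyFitObs(permutedims(hcat(values(data[1])...)), data[2])

# The "no-op" method: applying `toy_obs` to its own output returns it unchanged.
toy_obs(observations::ToyFitObs) = observations

# Random-access interface; `I` is assumed to be a range or vector of indices.
Base.getindex(data::ToyFitObs, I) = ToyFitObs(data.A[:, I], data.y[I])
Base.length(data::ToyFitObs) = length(data.y)

X = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
y = [7.0, 8.0, 9.0]

observations = toy_obs((X, y))
subsample = observations[2:3]   # the last two observations
```

Without the second `toy_obs` method, calling `toy_obs(observations)` would raise a `MethodError`, which is the gap the extra declaration in the diff closes.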

src/obs.jl

Lines changed: 20 additions & 7 deletions
@@ -54,8 +54,19 @@ For each supported form of `data` in `fit(learner, data)`, it must be true that
5454
fit(learner, observations)` is equivalent to `model = fit(learner, data)`, whenever
5555
`observations = obs(learner, data)`. For each supported form of `data` in calls
5656
`predict(model, ..., data)` and `transform(model, data)`, where implemented, the calls
57-
`predict(model, ..., observations)` and `transform(model, observations)` are supported
58-
alternatives, whenever `observations = obs(model, data)`.
57+
`predict(model, ..., observations)` and `transform(model, observations)` must be supported
58+
alternatives with the same output, whenever `observations = obs(model, data)`.
59+
60+
Implicit in the above requirements is that `obs(learner, _)` and `obs(model, _)` are
61+
involutive, meaning both the following hold:
62+
63+
```julia
64+
obs(learner, obs(learner, data)) == obs(learner, data)
65+
obs(model, obs(model, data)) == obs(model, data)
66+
```
67+
68+
If one overloads `obs`, one typically needs additional overloadings to guarantee
69+
involutivity.
5970
6071
The fallback for `obs` is `obs(model_or_learner, data) = data`, and the fallback for
6172
`LearnAPI.data_interface(learner)` is `LearnAPI.RandomAccess()`. For details refer to
@@ -67,14 +78,16 @@ to be overloaded. However, the user will get no performance benefits by using `o
6778
that case.
6879
6980
When overloading `obs(learner, data)` to output new model-specific representations of
70-
data, it may be necessary to also overload [`LearnAPI.features`](@ref),
71-
[`LearnAPI.target`](@ref) (supervised learners), and/or [`LearnAPI.weights`](@ref) (if
72-
weights are supported), for extracting relevant parts of the representation.
81+
data, it may be necessary to also overload [`LearnAPI.features(learner,
82+
observations)`](@ref), [`LearnAPI.target(learner, observations)`](@ref) (supervised
83+
learners), and/or [`LearnAPI.weights(learner, observations)`](@ref) (if weights are
84+
supported), for each kind of output `observations` of `obs(learner, data)`.
7385
7486
## Sample implementation
7587
76-
Refer to the "Anatomy of an Implementation" section of the LearnAPI.jl
77-
[manual](https://juliaai.github.io/LearnAPI.jl/dev/).
88+
Refer to the ["Anatomy of an
89+
Implementation"](https://juliaai.github.io/LearnAPI.jl/dev/anatomy_of_an_implementation/#Providing-an-advanced-data-interface)
90+
section of the LearnAPI.jl manual.
7891
7992
8093
"""
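
The predict-side requirements in this docstring can be checked on a toy model. This is a hedged sketch under stated assumptions rather than LearnAPI.jl code: `ToyFitted`, `toy_obs` and `toy_predict` are hypothetical stand-ins for `RidgeFitted`, `LearnAPI.obs(model, _)` and `LearnAPI.predict`, and the input table is assumed to be a named tuple of columns. It checks that predicting from `observations` gives the same output as predicting from `data`, and that applying `toy_obs` twice gives the same result as applying it once, that is, `f(f(x)) == f(x)`.

```julia
# Hypothetical stand-in for a fitted ridge model.
struct ToyFitted
    coefficients::Vector{Float64}
end

# Stand-in for `obs(model, data)`: table in, matrix out (one observation per column).
toy_obs(::ToyFitted, Xnew) = permutedims(hcat(values(Xnew)...))

# Extra method so that `toy_obs` accepts its own output unchanged.
toy_obs(::ToyFitted, observations::AbstractArray) = observations

# Prediction on pre-processed observations.
toy_predict(model::ToyFitted, observations::AbstractMatrix) =
    observations' * model.coefficients

# Prediction on raw data first routes through `toy_obs`.
toy_predict(model::ToyFitted, Xnew) = toy_predict(model, toy_obs(model, Xnew))

model = ToyFitted([0.5, -1.0])
Xnew = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
observations = toy_obs(model, Xnew)

same_output = toy_predict(model, Xnew) == toy_predict(model, observations)
obs_twice_equals_once = toy_obs(model, toy_obs(model, Xnew)) == toy_obs(model, Xnew)
```

Dispatching on `AbstractArray` for the pass-through method works here only because raw inputs are assumed never to be arrays; a model that accepts matrices as raw data would need a dedicated wrapper type instead.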
