Merged
39 changes: 31 additions & 8 deletions docs/src/anatomy_of_an_implementation.md
@@ -1,20 +1,27 @@
# Anatomy of an Implementation

This section explains a detailed implementation of the LearnAPI for naive [ridge
This section explains a detailed implementation of the LearnAPI.jl for naive [ridge
regression](https://en.wikipedia.org/wiki/Ridge_regression) with no intercept. The kind of
workflow we want to enable has been previewed in [Sample workflow](@ref). Readers can also
refer to the [demonstration](@ref workflow) of the implementation given later.

A transformer ordinarily implements `transform` instead of
`predict`. For more on `predict` versus `transform`, see [Predict or transform?](@ref)
The core LearnAPI.jl pattern looks like this:

```julia
model = fit(algorithm, data)
predict(model, newdata)
```

A transformer ordinarily implements `transform` instead of `predict`. For more on
`predict` versus `transform`, see [Predict or transform?](@ref)
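To make this concrete, the transformer variant of the pattern can be sketched self-containedly. The sketch below uses plain Julia functions and invented names (`Standardizer`, `StandardizerFitted`); it stands outside LearnAPI.jl and only illustrates the calling convention:

```julia
# Illustration only: a toy transformer following the fit/transform pattern.
# `Standardizer` and `StandardizerFitted` are invented names, not LearnAPI.jl API.
struct Standardizer end

struct StandardizerFitted
    mean::Float64
    std::Float64
end

# `fit` learns the transformation's parameters from `data`:
function fit(::Standardizer, data)
    m = sum(data) / length(data)
    s = sqrt(sum(abs2, data .- m) / length(data))
    StandardizerFitted(m, s)
end

# `transform` applies the learned transformation to new data:
transform(model::StandardizerFitted, newdata) = (newdata .- model.mean) ./ model.std
```

A call like `transform(fit(Standardizer(), X), Xnew)` then mirrors the `predict(fit(algorithm, data), newdata)` workflow above.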

!!! note

New implementations of `fit`, `predict`, etc.,
always have a *single* `data` argument, as in
`LearnAPI.fit(algorithm, data; verbosity=1) = ...`.
For convenience, user-calls, such as `fit(algorithm, X, y)`, automatically fallback
to `fit(algorithm, (X, y))`.
always have a *single* `data` argument, as above.
For convenience, a signature such as `fit(algorithm, X, y)`, calling
`fit(algorithm, (X, y))`, can be added, but the LearnAPI.jl specification is
silent on the meaning or existence of signatures with extra arguments.

!!! note

@@ -52,7 +59,7 @@ nothing # hide

Instances of `Ridge` will be [algorithms](@ref algorithms), in LearnAPI.jl parlance.

Associated with each new type of LearnAPI [algorithm](@ref algorithms) will be a keyword
Associated with each new type of LearnAPI.jl [algorithm](@ref algorithms) will be a keyword
argument constructor, providing default values for all properties (struct fields) that are
not other algorithms, and we must implement [`LearnAPI.constructor(algorithm)`](@ref), for
recovering the constructor from an instance:
@@ -244,6 +251,14 @@ in LearnAPI.functions(algorithm)`, for every instance `algorithm`. With [some
exceptions](@ref trait_contract), the value of a trait should depend only on the *type* of
the argument.

## Signatures added for convenience

We add one `fit` signature for user-convenience only. The LearnAPI.jl specification has
nothing to say about `fit` signatures with more than two positional arguments.

```@example anatomy
LearnAPI.fit(algorithm::Ridge, X, y; kwargs...) = fit(algorithm, (X, y); kwargs...)
```
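The delegation idiom behind such a convenience method can be illustrated in isolation. In the sketch below, `fit` is a stand-in function with invented behavior, not `LearnAPI.fit`; only the forwarding pattern matters:

```julia
# Illustration only: how a convenience signature can delegate to the
# single-`data`-argument form. This `fit` is a stand-in, not LearnAPI.fit.
fit(algorithm, data::Tuple; kwargs...) = (algorithm, data)  # pretend training happens here

# convenience method, slurping `X` and `y` into a tuple:
fit(algorithm, X, y; kwargs...) = fit(algorithm, (X, y); kwargs...)
```

Both call forms then produce identical results, with only the tuple form mandated by the specification.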

## [Demonstration](@id workflow)

@@ -466,6 +481,14 @@ overload the trait, [`LearnAPI.data_interface(algorithm)`](@ref). See [Data
interfaces](@ref data_interfaces) for details.


### Addition of signatures for user convenience

As above, we add a signature which plays no role vis-à-vis LearnAPI.jl.

```@example anatomy2
LearnAPI.fit(algorithm::Ridge, X, y; kwargs...) = fit(algorithm, (X, y); kwargs...)
```

## Demonstration of an advanced `obs` workflow

We now can train and predict using internal data representations, resampled using the
2 changes: 2 additions & 0 deletions docs/src/common_implementation_patterns.md
@@ -1,5 +1,7 @@
# Common Implementation Patterns

!!! warning

!!! warning

This section is only an implementation guide. The definitive specification of the
20 changes: 17 additions & 3 deletions docs/src/fit_update.md
@@ -11,8 +11,6 @@ A "static" algorithm is one that does not generalize to new observations (e.g.,
clustering algorithms); there is no training data and the algorithm is executed by
`predict` or `transform` which receive the data. See example below.

When `fit` expects a tuple form of argument, `data = (X1, ..., Xn)`, then the signature
`fit(algorithm, X1, ..., Xn)` is also provided.

### Updating

@@ -32,7 +30,7 @@ Supposing `Algorithm` is some supervised classifier type, with an iteration parameter

```julia
algorithm = Algorithm(n=100)
model = fit(algorithm, (X, y)) # or `fit(algorithm, X, y)`
model = fit(algorithm, (X, y))

# Predict probability distributions:
ŷ = predict(model, Distribution(), Xnew)
@@ -76,6 +74,22 @@ labels = predict(algorithm, X)
LearnAPI.extras(model)
```
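The static pattern can be sketched self-containedly. All names below (`ThresholdClusterer`, `ThresholdClustererFitted`) are invented for illustration, and `fit`/`predict` are plain functions rather than the LearnAPI.jl generics:

```julia
# Illustration only: the "static" pattern. `fit` consumes no data and merely
# wraps the algorithm; `predict` receives the data, does the actual work, and
# records byproducts (here, the labels) on the mutable model.
struct ThresholdClusterer
    threshold::Float64
end

mutable struct ThresholdClustererFitted
    algorithm::ThresholdClusterer
    labels::Union{Nothing,Vector{Int}}
end

fit(algorithm::ThresholdClusterer) = ThresholdClustererFitted(algorithm, nothing)

function predict(model::ThresholdClustererFitted, X)
    labels = [x > model.algorithm.threshold ? 2 : 1 for x in X]
    model.labels = labels  # byproduct, retrievable later (cf. `LearnAPI.extras`)
    labels
end
```

Storing the labels on `model` is what allows a byproduct accessor in the style of `LearnAPI.extras(model)` to return them after `predict` has run.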

### Density estimation

In density estimation, `fit` consumes no features, only a target variable; `predict`,
which consumes no data, returns the learned density:

```julia
model = fit(algorithm, y) # no features
predict(model) # shortcut for `predict(model, Distribution())`
```

A one-liner will typically be implemented as well:

```julia
predict(algorithm, y)
```
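As a self-contained illustration of this calling pattern, here is a toy density estimator; `NormalEstimator` and `NormalFitted` are invented names, the learned density is returned as a plain function, and `fit`/`predict` are ordinary functions rather than the LearnAPI.jl generics:

```julia
# Illustration only: the density-estimation pattern. `fit` consumes only the
# target `y`; `predict` consumes no data and returns the learned density.
struct NormalEstimator end

struct NormalFitted
    mu::Float64
    sigma::Float64
end

function fit(::NormalEstimator, y)
    mu = sum(y) / length(y)
    sigma = sqrt(sum(abs2, y .- mu) / length(y))
    NormalFitted(mu, sigma)
end

# the learned density, represented here as a callable:
predict(model::NormalFitted) =
    x -> exp(-(x - model.mu)^2 / (2 * model.sigma^2)) / (model.sigma * sqrt(2 * pi))

# the conventional one-liner:
predict(estimator::NormalEstimator, y) = predict(fit(estimator, y))
```

With this in place, `predict(NormalEstimator(), y)` fits and returns the density in one call.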

## Implementation guide

### Training
7 changes: 3 additions & 4 deletions docs/src/predict_transform.md
@@ -6,16 +6,15 @@ transform(model, data)
inverse_transform(model, data)
```

When a method expects a tuple form of argument, `data = (X1, ..., Xn)`, then a slurping
signature is also provided, as in `transform(model, X1, ..., Xn)`.

Versions without the `data` argument may also appear, for example in [Density
estimation](@ref).

## [Typical workflows](@id predict_workflow)

Train some supervised `algorithm`:

```julia
model = fit(algorithm, (X, y)) # or `fit(algorithm, X, y)`
model = fit(algorithm, (X, y))
```

Predict probability distributions:
37 changes: 15 additions & 22 deletions src/fit_update.jl
@@ -14,13 +14,10 @@ The second signature is provided by algorithms that do not generalize to new obs
..., data)` carries out the actual algorithm execution, writing any byproducts of that
operation to the mutable object `model` returned by `fit`.

Whenever `fit` expects a tuple form of argument, `data = (X1, ..., Xn)`, then the
signature `fit(algorithm, X1, ..., Xn)` is also provided.

For example, a supervised classifier will typically admit this workflow:
For example, a supervised classifier might have a workflow like this:

```julia
model = fit(algorithm, (X, y)) # or `fit(algorithm, X, y)`
model = fit(algorithm, (X, y))
ŷ = predict(model, Xnew)
```

@@ -33,24 +30,22 @@ See also [`predict`](@ref), [`transform`](@ref), [`inverse_transform`](@ref),

# New implementations

Implementation is compulsory. The signature must include `verbosity`. Fallbacks provide
the data slurping versions. A fallback for the first signature calls the second, ignoring
`data`:
Implementation of exactly one of the signatures is compulsory. If `fit(algorithm;
verbosity=1)` is implemented, then the trait [`LearnAPI.is_static`](@ref) must be
overloaded to return `true`.

```julia
fit(algorithm, data; kwargs...) = fit(algorithm; kwargs...)
```
The signature must include `verbosity`.

If only the `fit(algorithm)` signature is explicitly implemented, then the trait
[`LearnAPI.is_static`](@ref) must be overloaded to return `true`.
The LearnAPI.jl specification has nothing to say regarding `fit` signatures with more than
two positional arguments. For convenience, an algorithm is free to implement a slurping
signature, such as `fit(algorithm, X, y, extras...) = fit(algorithm, (X, y, extras...))`,
but LearnAPI.jl does not guarantee that such signatures are actually implemented.

$(DOC_DATA_INTERFACE(:fit))

"""
fit(algorithm, data; kwargs...) =
fit(algorithm; kwargs...)
fit(algorithm, data1, datas...; kwargs...) =
fit(algorithm, (data1, datas...); kwargs...)
function fit end


# # UPDATE AND COUSINS

@@ -91,7 +86,7 @@ Implementation is optional. The signature must include
See also [`LearnAPI.clone`](@ref)

"""
update(model, data1, datas...; kwargs...) = update(model, (data1, datas...); kwargs...)
function update end

"""
update_observations(model, new_data; verbosity=1, parameter_replacements...)
@@ -127,8 +122,7 @@ Implementation is optional. The signature must include
See also [`LearnAPI.clone`](@ref).

"""
update_observations(algorithm, data1, datas...; kwargs...) =
update_observations(algorithm, (data1, datas...); kwargs...)
function update_observations end

"""
update_features(model, new_data; verbosity=1, parameter_replacements...)
@@ -154,5 +148,4 @@ Implementation is optional. The signature must include
See also [`LearnAPI.clone`](@ref).

"""
update_features(algorithm, data1, datas...; kwargs...) =
update_features(algorithm, (data1, datas...); kwargs...)
function update_features end
49 changes: 24 additions & 25 deletions src/predict_transform.jl
@@ -13,12 +13,20 @@ DOC_MUTATION(op) =

"""

DOC_SLURPING(op) =
"""

An algorithm is free to implement `$op` signatures with additional positional
arguments (e.g., data-slurping signatures), but LearnAPI.jl is silent about their
interpretation or existence.

"""

DOC_MINIMIZE(func) =
"""

If, additionally, [`LearnAPI.strip(model)`](@ref) is overloaded, then the following identity
must hold:
If, additionally, [`LearnAPI.strip(model)`](@ref) is overloaded, then the following
identity must hold:

```julia
$func(LearnAPI.strip(model), args...) = $func(model, args...)
@@ -63,7 +71,7 @@ which lists all supported target proxies.
The argument `model` is anything returned by a call of the form `fit(algorithm, ...)`.

If `LearnAPI.features(LearnAPI.algorithm(model)) == nothing`, then argument `data` is
omitted. An example is density estimators.
omitted in both signatures. An example is density estimators.

# Example

@@ -79,20 +87,20 @@ See also [`fit`](@ref), [`transform`](@ref), [`inverse_transform`](@ref).

# Extended help

If `predict` supports data in the form of a tuple `data = (X1, ..., Xn)`, then a slurping
signature is also provided, as in `predict(model, X1, ..., Xn)`.

Note `predict` does not mutate any argument, except in the special case
Note `predict` must not mutate any argument, except in the special case
`LearnAPI.is_static(algorithm) == true`.

# New implementations

If there is no notion of a "target" variable in the LearnAPI.jl sense, or you need an
operation with an inverse, implement [`transform`](@ref) instead.

Implementation is optional. Only the first signature is implemented, but each
`kind_of_proxy` that gets an implementation must be added to the list returned by
[`LearnAPI.kinds_of_proxy`](@ref).
Implementation is optional. Only the first signature (with or without the `data` argument)
is implemented, but each `kind_of_proxy` that gets an implementation must be added to the
list returned by [`LearnAPI.kinds_of_proxy`](@ref).

If `data` is not present in the implemented signature (e.g., for density estimators), then
[`LearnAPI.features(algorithm, data)`](@ref) must return `nothing`.

$(DOC_IMPLEMENTED_METHODS(":(LearnAPI.predict)"))

@@ -106,23 +114,12 @@ $(DOC_DATA_INTERFACE(:predict))
predict(model, data) = predict(model, kinds_of_proxy(algorithm(model)) |> first, data)
predict(model) = predict(model, kinds_of_proxy(algorithm(model)) |> first)

# automatic slurping of multiple data arguments:
predict(model, k::KindOfProxy, data1, data2, datas...; kwargs...) =
predict(model, k, (data1, data2, datas...); kwargs...)
predict(model, data1, data2, datas...; kwargs...) =
predict(model, (data1, data2, datas...); kwargs...)



"""
transform(model, data)

Return a transformation of some `data`, using some `model`, as returned by
[`fit`](@ref).

For `data` that consists of a tuple, a slurping version is also provided, i.e., you can do
`transform(model, X1, X2, X3)` in place of `transform(model, (X1, X2, X3))`.

# Example

Below, `X` and `Xnew` are data of the same form.
@@ -157,8 +154,10 @@ See also [`fit`](@ref), [`predict`](@ref),

# New implementations

Implementation for new LearnAPI.jl algorithms is optional. A fallback provides the
slurping version. $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.transform)"))
Implementation for new LearnAPI.jl algorithms is
optional. $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.transform)"))

$(DOC_SLURPING(:transform))

$(DOC_MINIMIZE(:transform))

@@ -167,8 +166,8 @@ $(DOC_MUTATION(:transform))
$(DOC_DATA_INTERFACE(:transform))

"""
transform(model, data1, data2, datas...; kwargs...) =
transform(model, (data1, data2, datas...); kwargs...) # automatic slurping
function transform end


"""
inverse_transform(model, data)
14 changes: 0 additions & 14 deletions test/fit_update.jl

This file was deleted.

7 changes: 7 additions & 0 deletions test/patterns/ensembling.jl
@@ -160,6 +160,13 @@ LearnAPI.strip(model::EnsembleFitted) = EnsembleFitted(
)
)

# convenience method:
LearnAPI.fit(algorithm::Ensemble, X, y, extras...; kwargs...) =
fit(algorithm, (X, y, extras...); kwargs...)
LearnAPI.update(algorithm::EnsembleFitted, X, y, extras...; kwargs...) =
update(algorithm, (X, y, extras...); kwargs...)


# synthetic test data:
N = 10 # number of observations
train = 1:6
Expand Down
16 changes: 13 additions & 3 deletions test/patterns/gradient_descent.jl
@@ -227,9 +227,9 @@ function LearnAPI.update_observations(
)

# unpack data:
X = observations.X
y_hot = observations.y_hot
classes = observations.classes
X = observations_new.X
y_hot = observations_new.y_hot
classes = observations_new.classes
nclasses = length(classes)

classes == model.classes || error("New training target has incompatible classes.")
@@ -328,6 +328,16 @@ LearnAPI.training_losses(model::PerceptronClassifierFitted) = model.losses
)


# ### Convenience methods

LearnAPI.fit(algorithm::PerceptronClassifier, X, y; kwargs...) =
fit(algorithm, (X, y); kwargs...)
LearnAPI.update_observations(algorithm::PerceptronClassifier, X, y; kwargs...) =
update_observations(algorithm, (X, y); kwargs...)
LearnAPI.update(algorithm::PerceptronClassifier, X, y; kwargs...) =
update(algorithm, (X, y); kwargs...)


# ## Tests

# synthetic test data:
Expand Down