Skip to content

Commit 718cea8

Browse files
committed
make predict dispatch on kind of target proxy
oops update the docs tweaks TargetProxy -> KindOfProxy tweaks
1 parent dfc4551 commit 718cea8

10 files changed

+421
-409
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Anthony D. Blaom <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89

910
[extras]

docs/src/algorithm_traits.md

Lines changed: 47 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
> **Summary.** Traits allow one to promise particular behaviour for an algorithm, such as:
44
> *This algorithm supports per-observation weights, which must appear as the third
5-
> *argument of > `fit`*, or *This algorithm predicts probability distributions for the
6-
> *target*, or *This > algorithm's `transform` method predicts `Real` vectors*.
5+
> argument of `fit`*, or *This algorithm's `transform` method predicts `Real` vectors*.
76
87
For any (non-trivial) algorithm, [`LearnAPI.functions`](@ref)`(algorithm)` must be
9-
overloaded to list the LearnAPI methods that have been explicitly overloaded (algorithm
10-
traits excluded). Otherwise, overloading traits is optional, except where required by the
11-
implementation of some LearnAPI method and explicitly documented in that method's
8+
overloaded to list the LearnAPI methods that have been explicitly implemented/overloaded
9+
(algorithm traits excluded). Overloading other traits is optional, except where required
10+
by the implementation of some LearnAPI method and explicitly documented in that method's
1211
docstring.
1312

1413
Traits are often called on instances but are usually *defined* on algorithm *types*, as in
@@ -32,30 +31,41 @@ t(algorithm) = t(typeof(algorithm))
3231
This means `LearnAPI.is_pure_julia(algorithm) = true` whenever `algorithm isa MyAlgorithmType` in the
3332
above example.
3433

35-
Every trait has a global fallback implementation for `::Type`.
34+
Every trait has a global fallback implementation for `::Type`. See the table below.
35+
36+
## When traits depdend on more than algorithm type
3637

3738
Traits that vary from instance to instance of the same type are disallowed, except in the
38-
case of composite algorithms (`is_wrapper(algorithm) = true`) where this is unavoidable. (One
39-
reason for this is so one can associate with each algorithm type a unique set of trait-based
40-
"algorithm metadata" for inclusion in searchable algorithm databases.) This requirement
41-
occasionally requires that an existing algorithm implementation be split into separate
42-
LearnAPI implementations (e.g., one for regression and another for classification).
39+
case of composite algorithms (`is_wrapper(algorithm) = true`) where this is typically
40+
unavoidable. The reason for this is so one can associate, with each non-composite
41+
algorithm type, unique trait-based "algorithm metadata", for inclusion in searchable
42+
algorithm databases. This requirement occasionally requires that an existing algorithm
43+
implementation be split into separate LearnAPI implementations (e.g., one for regression
44+
and another for classification).
45+
46+
## Special two-argument traits
47+
48+
The two-argument version of [`LearnAPI.predict_output_scitype`](@ref) and
49+
[`LearnAPI.predict_output_scitype`](@ref) are the only overloadable traits with more than
50+
one argument. They cannot be declared using the `@trait` macro.
4351

44-
**Ordinary traits** are available for overloading by any new LearnAPI implementation. **Derived
45-
traits** are not.
52+
## Trait summary
4653

47-
## Ordinary traits
54+
**Overloadable traits** are available for overloading by any new LearnAPI
55+
implementation. **Derived traits** are not, and should not be called by performance
56+
critical code
57+
58+
## Overloadable traits
4859

4960
In the examples column of the table below, `Table`, `Continuous`, `Sampleable` are names owned by the
5061
package [ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.jl/).
5162

5263
| trait | fallback value | return value | example |
5364
|:-------------------------------------------------|:----------------------|:--------------|:--------|
5465
| [`LearnAPI.functions`](@ref)`(algorithm)` | `()` | implemented LearnAPI functions (traits excluded) | `(:fit, :predict)` |
55-
| [`LearnAPI.predict_proxy`](@ref)`(algorithm)` | `LearnAPI.None()` | form of target proxy output by `predict` | `LearnAPI.Distribution()` |
56-
| [`LearnAPI.predict_joint_proxy`](@ref)`(algorithm)` | `LearnAPI.None()` | form of target proxy output by `predict_joint` | `LearnAPI.Distribution()` |
57-
| [`LearnAPI.position_of_target`](@ref)`(algorithm)` | `0` | † the positional index of the **target** in `data` in `fit(..., data...; metadata)` calls | 2 |
58-
| [`LearnAPI.position_of_weights`](@ref)`(algorithm)` | `0` | † the positional index of **per-observation weights** in `data` in `fit(..., data...; metadata)` | 3 |
66+
| [`LearnAPI.preferred_kind_of_proxy`](@ref)`(algorithm)` | `LearnAPI.None()` | an instance `tp` of `KindOfProxy` for which an implementation of `LearnAPI.predict(algorithm, tp, ...)` is guaranteed. | `LearnAPI.Distribution()` |
67+
| [`LearnAPI.position_of_target`](@ref)`(algorithm)` | `0` | ¹ the positional index of the **target** in `data` in `fit(..., data...; metadata)` calls | 2 |
68+
| [`LearnAPI.position_of_weights`](@ref)`(algorithm)` | `0` | ¹ the positional index of **per-observation weights** in `data` in `fit(..., data...; metadata)` | 3 |
5969
| [`LearnAPI.descriptors`](@ref)`(algorithm)` | `()` | lists one or more suggestive algorithm descriptors from `LearnAPI.descriptors()` | (:classifier, :probabilistic) |
6070
| [`LearnAPI.is_pure_julia`](@ref)`(algorithm)` | `false` | is `true` if implementation is 100% Julia code | `true` |
6171
| [`LearnAPI.pkg_name`](@ref)`(algorithm)` | `"unknown"` | name of package providing core code (may be different from package providing LearnAPI.jl implementation) | `"DecisionTree"` |
@@ -66,54 +76,47 @@ package [ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.
6676
| [`LearnAPI.human_name`](@ref)`(algorithm)` | type name with spaces | human name for the algorithm; should be a noun | "elastic net regressor" |
6777
| [`LearnAPI.iteration_parameter`](@ref)`(algorithm)` | `nothing` | symbolic name of an iteration parameter | :epochs |
6878
| [`LearnAPI.fit_keywords`](@ref)`(algorithm)` | `()` | tuple of symbols for keyword arguments accepted by `fit` (corresponding to metadata) | `(:class_weights,)` |
69-
| [`LearnAPI.fit_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `fit(algorithm, verbosity, data...)`†† | `Tuple{Table(Continuous), AbstractVector{Continuous}}` |
70-
| [`LearnAPI.fit_observation_scitype`](@ref)`(algorithm)` | `Union{}`| upper bound on `scitype(observation)` for `observation` in `data` and `data` in `fit(algorithm, verbosity, data...)`†† | `Tuple{AbstractVector{Continuous}, Continuous}` |
71-
| [`LearnAPI.fit_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `type(data)` in `fit(algorithm, verbosity, data...)`†† | `Tuple{AbstractMatrix{<:Real}, AbstractVector{<:Real}}` |
79+
| [`LearnAPI.fit_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `fit(algorithm, verbosity, data...)`² | `Tuple{Table(Continuous), AbstractVector{Continuous}}` |
80+
| [`LearnAPI.fit_observation_scitype`](@ref)`(algorithm)` | `Union{}`| upper bound on `scitype(observation)` for `observation` in `data` and `data` in `fit(algorithm, verbosity, data...)`² | `Tuple{AbstractVector{Continuous}, Continuous}` |
81+
| [`LearnAPI.fit_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `type(data)` in `fit(algorithm, verbosity, data...)`² | `Tuple{AbstractMatrix{<:Real}, AbstractVector{<:Real}}` |
7282
| [`LearnAPI.fit_observation_type`](@ref)`(algorithm)` | `Union{}`| upper bound on `type(observation)` for `observation` in `data` and `data` in `fit(algorithm, verbosity, data...)`* | `Tuple{AbstractVector{<:Real}, Real}` |
73-
| [`LearnAPI.predict_input_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `predict(algorithm, fitted_params, data...)`†† | `Table(Continuous)` |
74-
| [`LearnAPI.predict_output_scitype`](@ref)`(algorithm)` | `Any` | upper bound on `scitype(first(predict(algorithm, ...)))` | `AbstractVector{Continuous}` |
75-
| [`LearnAPI.predict_input_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `typeof(data)` in `predict(algorithm, fitted_params, data...)`†† | `AbstractMatrix{<:Real}` |
76-
| [`LearnAPI.predict_output_type`](@ref)`(algorithm)` | `Any` | upper bound on `typeof(first(predict(algorithm, ...)))` | `AbstractVector{<:Real}` |
77-
| [`LearnAPI.predict_joint_input_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `predict_joint(algorithm, fitted_params, data...)`†† |`Table(Continuous)` |
78-
| [`LearnAPI.predict_joint_output_scitype`](@ref)`(algorithm)` | `Any` | upper bound on `scitype(first(predict_joint(algorithm, ...)))` | `Sampleable{<:AbstractVector{Continuous}}` |
79-
| [`LearnAPI.predict_joint_input_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `typeof(data)` in `predict_joint(algorithm, fitted_params, data...)`†† | `AbstractMatrix{<:Real}` |
80-
| [`LearnAPI.predict_joint_output_type`](@ref)`(algorithm)` | `Any` | upper bound on `typeof(first(predict_joint(algorithm, ...)))` | `Distributions.Sampleable{Distributions.Multivariate,Distributions.Continuous}` |
81-
| [`LearnAPI.transform_input_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `transform(algorithm, fitted_params, data...)`†† | `Table(Continuous)` |
83+
| [`LearnAPI.predict_input_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `predict(algorithm, fitted_params, data...)`² | `Table(Continuous)` |
84+
| [`LearnAPI.predict_output_scitype`](@ref)`(algorithm, kind_of_proxy)` | `Any` | upper bound on `scitype(first(predict(algorithm, kind_of_proxy, ...)))` | `AbstractVector{Continuous}` |
85+
| [`LearnAPI.predict_input_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `typeof(data)` in `predict(algorithm, fitted_params, data...)`² | `AbstractMatrix{<:Real}` |
86+
| [`LearnAPI.predict_output_type`](@ref)`(algorithm, kind_of_proxy)` | `Any` | upper bound on `typeof(first(predict(algorithm, kind_of_proxy, ...)))` | `AbstractVector{<:Real}` |
87+
| [`LearnAPI.transform_input_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `transform(algorithm, fitted_params, data...)`² | `Table(Continuous)` |
8288
| [`LearnAPI.transform_output_scitype`](@ref)`(algorithm)` | `Any` | upper bound on `scitype(first(transform(algorithm, ...)))` | `Table(Continuous)` |
83-
| [`LearnAPI.transform_input_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `typeof(data)` in `transform(algorithm, fitted_params, data...)`†† | `AbstractMatrix{<:Real}}` |
89+
| [`LearnAPI.transform_input_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `typeof(data)` in `transform(algorithm, fitted_params, data...)`² | `AbstractMatrix{<:Real}}` |
8490
| [`LearnAPI.transform_output_type`](@ref)`(algorithm)` | `Any` | upper bound on `typeof(first(transform(algorithm, ...)))` | `AbstractMatrix{<:Real}` |
85-
| [`LearnAPI.inverse_transform_input_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `inverse_transform(algorithm, fitted_params, data...)`†† | `Table(Continuous)` |
86-
| [`LearnAPI.inverse_transform_output_scitype`](@ref)`(algorithm)` | `Any` | upper bound on `scitype(first(inverse_transform(algorithm, ...)))` | `Table(Continuous)` |
87-
| [`LearnAPI.inverse_transform_input_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `typeof(data)` in `inverse_transform(algorithm, fitted_params, data...)`†† | `AbstractMatrix{<:Real}` |
88-
| [`LearnAPI.inverse_transform_output_type`](@ref)`(algorithm)` | `Any` | upper bound on `typeof(first(inverse_transform(algorithm, ...)))` | `AbstractMatrix{<:Real}` |
89-
9091

91-
If the value is `0`, then the variable in boldface type is not supported and not
92+
¹ If the value is `0`, then the variable in boldface type is not supported and not
9293
expected to appear in `data`. If `length(data)` is less than the trait value, then `data`
9394
is understood to exclude the variable, but note that `fit` can have multiple signatures of
9495
varying lengths, as in `fit(algorithm, verbosity, X, y)` and `fit(algorithm, verbosity, X, y,
9596
w)`. A non-zero value is a promise that `fit` includes a signature of sufficient length to
9697
include the variable.
9798

98-
†† Assuming no [optional data interface](@ref data_interface) is implemented. See docstring
99+
² Assuming no [optional data interface](@ref data_interface) is implemented. See docstring
99100
for the general case.
100101

101102

102103
## Derived Traits
103104

104105
The following convenience methods are provided but intended for overloading:
105106

106-
| trait | return value | example |
107-
|:-----------------------------|:------------------------------------------|:--------|
108-
| `LearnAPI.name(algorithm)` | algorithm type name as string | "PCA" |
109-
| `LearnAPI.is_algorithm(algorithm)` | `true` if `functions(algorithm)` is not empty | `true` |
107+
| trait | return value | example |
108+
|:-------------------------------------|:------------------------------------------|:-----------|
109+
| `LearnAPI.name(algorithm)` | algorithm type name as string | "PCA" |
110+
| `LearnAPI.is_algorithm(algorithm)` | `true` if `functions(algorithm)` is not empty | `true` |
111+
| [`LearnAPI.predict_output_scitype`](@ref)(algorithm) | dictionary of upper bounds on the scitype of predictions, keyed on subtypes of [`LearnAPI.KindOfProxy`](@ref) |
112+
| [`LearnAPI.predict_output_type`](@ref)(algorithm) | dictionary of upper bounds on the type of predictions, keyed on subtypes of [`LearnAPI.KindOfProxy`](@ref) |
113+
110114

111115
## Reference
112116

113117
```@docs
114118
LearnAPI.functions
115-
LearnAPI.predict_proxy
116-
LearnAPI.predict_joint_proxy
119+
LearnAPI.preferred_kind_of_proxy
117120
LearnAPI.position_of_target
118121
LearnAPI.position_of_weights
119122
LearnAPI.descriptors
@@ -134,16 +137,8 @@ LearnAPI.predict_input_scitype
134137
LearnAPI.predict_output_scitype
135138
LearnAPI.predict_input_type
136139
LearnAPI.predict_output_type
137-
LearnAPI.predict_joint_input_scitype
138-
LearnAPI.predict_joint_output_scitype
139-
LearnAPI.predict_joint_input_type
140-
LearnAPI.predict_joint_output_type
141140
LearnAPI.transform_input_scitype
142141
LearnAPI.transform_output_scitype
143142
LearnAPI.transform_input_type
144143
LearnAPI.transform_output_type
145-
LearnAPI.inverse_transform_input_scitype
146-
LearnAPI.inverse_transform_output_scitype
147-
LearnAPI.inverse_transform_input_type
148-
LearnAPI.inverse_transform_output_type
149144
```

0 commit comments

Comments
 (0)