Skip to content

Commit d6b394e

Browse files
authored
Merge pull request #17 from JuliaAI/predict
Use "algorithm" instead of "model" and dispatch `predict` on kind of target proxy
2 parents 3711597 + cb45f87 commit d6b394e

29 files changed

+882
-890
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Anthony D. Blaom <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
78
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89

910
[extras]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# LearnAPI.jl
22

3-
A Julia interface for training and applying machine learning models.
3+
A base Julia interface for machine learning and statistics
44

55

66
**Devlopement Status:**

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ makedocs(;
1616
"Predict and other operations" => "operations.md",
1717
"Accessor Functions" => "accessor_functions.md",
1818
"Optional Data Interface" => "optional_data_interface.md",
19-
"Model Traits" => "model_traits.md",
19+
"Algorithm Traits" => "algorithm_traits.md",
2020
"Common Implementation Patterns" => "common_implementation_patterns.md",
2121
"Testing an Implementation" => "testing_an_implementation.md",
2222
],

docs/src/accessor_functions.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Accessor Functions
22

33
> **Summary.** While byproducts of training are ordinarily recorded in the `report`
4-
> component of the output of `fit`/`update!`/`ingest!`, some families of models report an
5-
> item that is likely shared by multiple model types, and it is useful to have common
4+
> component of the output of `fit`/`update!`/`ingest!`, some families of algorithms report an
5+
> item that is likely shared by multiple algorithm types, and it is useful to have common
66
> interface for accessing these directly. Training losses and feature importances are two
77
> examples.
88

docs/src/algorithm_traits.md

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Algorithm Traits
2+
3+
> **Summary.** Traits allow one to promise particular behaviour for an algorithm, such as:
4+
> *This algorithm supports per-observation weights, which must appear as the third
5+
> argument of `fit`*, or *This algorithm's `transform` method predicts `Real` vectors*.
6+
7+
For any (non-trivial) algorithm, [`LearnAPI.functions`](@ref)`(algorithm)` must be
8+
overloaded to list the LearnAPI methods that have been explicitly implemented/overloaded
9+
(algorithm traits excluded). Overloading other traits is optional, except where required
10+
by the implementation of some LearnAPI method and explicitly documented in that method's
11+
docstring.
12+
13+
Traits are often called on instances but are usually *defined* on algorithm *types*, as in
14+
15+
```julia
16+
LearnAPI.is_pure_julia(::Type{<:MyAlgorithmType}) = true
17+
```
18+
19+
which has the shorthand
20+
21+
```julia
22+
@trait MyAlgorithmType is_pure_julia=true
23+
```
24+
25+
So, for convenience, every trait `t` is provided the fallback implementation
26+
27+
```julia
28+
t(algorithm) = t(typeof(algorithm))
29+
```
30+
31+
This means `LearnAPI.is_pure_julia(algorithm) = true` whenever `algorithm isa MyAlgorithmType` in the
32+
above example.
33+
34+
Every trait has a global fallback implementation for `::Type`. See the table below.
35+
36+
## When traits depdend on more than algorithm type
37+
38+
Traits that vary from instance to instance of the same type are disallowed, except in the
39+
case of composite algorithms (`is_wrapper(algorithm) = true`) where this is typically
40+
unavoidable. The reason for this is so one can associate, with each non-composite
41+
algorithm type, unique trait-based "algorithm metadata", for inclusion in searchable
42+
algorithm databases. This requirement occasionally requires that an existing algorithm
43+
implementation be split into separate LearnAPI implementations (e.g., one for regression
44+
and another for classification).
45+
46+
## Special two-argument traits
47+
48+
The two-argument version of [`LearnAPI.predict_output_scitype`](@ref) and
49+
[`LearnAPI.predict_output_scitype`](@ref) are the only overloadable traits with more than
50+
one argument. They cannot be declared using the `@trait` macro.
51+
52+
## Trait summary
53+
54+
**Overloadable traits** are available for overloading by any new LearnAPI
55+
implementation. **Derived traits** are not, and should not be called by performance
56+
critical code
57+
58+
## Overloadable traits
59+
60+
In the examples column of the table below, `Table`, `Continuous`, `Sampleable` are names owned by the
61+
package [ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.jl/).
62+
63+
| trait | fallback value | return value | example |
64+
|:-------------------------------------------------|:----------------------|:--------------|:--------|
65+
| [`LearnAPI.functions`](@ref)`(algorithm)` | `()` | implemented LearnAPI functions (traits excluded) | `(:fit, :predict)` |
66+
| [`LearnAPI.preferred_kind_of_proxy`](@ref)`(algorithm)` | `LearnAPI.None()` | an instance `tp` of `KindOfProxy` for which an implementation of `LearnAPI.predict(algorithm, tp, ...)` is guaranteed. | `LearnAPI.Distribution()` |
67+
| [`LearnAPI.position_of_target`](@ref)`(algorithm)` | `0` | ¹ the positional index of the **target** in `data` in `fit(..., data...; metadata)` calls | 2 |
68+
| [`LearnAPI.position_of_weights`](@ref)`(algorithm)` | `0` | ¹ the positional index of **per-observation weights** in `data` in `fit(..., data...; metadata)` | 3 |
69+
| [`LearnAPI.descriptors`](@ref)`(algorithm)` | `()` | lists one or more suggestive algorithm descriptors from `LearnAPI.descriptors()` | (:classifier, :probabilistic) |
70+
| [`LearnAPI.is_pure_julia`](@ref)`(algorithm)` | `false` | is `true` if implementation is 100% Julia code | `true` |
71+
| [`LearnAPI.pkg_name`](@ref)`(algorithm)` | `"unknown"` | name of package providing core code (may be different from package providing LearnAPI.jl implementation) | `"DecisionTree"` |
72+
| [`LearnAPI.pkg_license`](@ref)`(algorithm)` | `"unknown"` | name of license of package providing core code | `"MIT"` |
73+
| [`LearnAPI.doc_url`](@ref)`(algorithm)` | `"unknown"` | url providing documentation of the core code | `"https://en.wikipedia.org/wiki/Decision_tree_learning"` |
74+
| [`LearnAPI.load_path`](@ref)`(algorithm)` | `"unknown"` | a string indicating where the struct for `typeof(algorithm)` is defined, beginning with name of package providing implementation | `FastTrees.LearnAPI.DecisionTreeClassifier` |
75+
| [`LearnAPI.is_wrapper`](@ref)`(algorithm)` | `false` | is `true` if one or more properties (fields) of `algorithm` may be an algorithm | `true` |
76+
| [`LearnAPI.human_name`](@ref)`(algorithm)` | type name with spaces | human name for the algorithm; should be a noun | "elastic net regressor" |
77+
| [`LearnAPI.iteration_parameter`](@ref)`(algorithm)` | `nothing` | symbolic name of an iteration parameter | :epochs |
78+
| [`LearnAPI.fit_keywords`](@ref)`(algorithm)` | `()` | tuple of symbols for keyword arguments accepted by `fit` (corresponding to metadata) | `(:class_weights,)` |
79+
| [`LearnAPI.fit_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `fit(algorithm, verbosity, data...)`² | `Tuple{Table(Continuous), AbstractVector{Continuous}}` |
80+
| [`LearnAPI.fit_observation_scitype`](@ref)`(algorithm)` | `Union{}`| upper bound on `scitype(observation)` for `observation` in `data` and `data` in `fit(algorithm, verbosity, data...)`² | `Tuple{AbstractVector{Continuous}, Continuous}` |
81+
| [`LearnAPI.fit_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `type(data)` in `fit(algorithm, verbosity, data...)`² | `Tuple{AbstractMatrix{<:Real}, AbstractVector{<:Real}}` |
82+
| [`LearnAPI.fit_observation_type`](@ref)`(algorithm)` | `Union{}`| upper bound on `type(observation)` for `observation` in `data` and `data` in `fit(algorithm, verbosity, data...)`* | `Tuple{AbstractVector{<:Real}, Real}` |
83+
| [`LearnAPI.predict_input_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `predict(algorithm, fitted_params, data...)`² | `Table(Continuous)` |
84+
| [`LearnAPI.predict_output_scitype`](@ref)`(algorithm, kind_of_proxy)` | `Any` | upper bound on `scitype(first(predict(algorithm, kind_of_proxy, ...)))` | `AbstractVector{Continuous}` |
85+
| [`LearnAPI.predict_input_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `typeof(data)` in `predict(algorithm, fitted_params, data...)`² | `AbstractMatrix{<:Real}` |
86+
| [`LearnAPI.predict_output_type`](@ref)`(algorithm, kind_of_proxy)` | `Any` | upper bound on `typeof(first(predict(algorithm, kind_of_proxy, ...)))` | `AbstractVector{<:Real}` |
87+
| [`LearnAPI.transform_input_scitype`](@ref)`(algorithm)` | `Union{}` | upper bound on `scitype(data)` in `transform(algorithm, fitted_params, data...)`² | `Table(Continuous)` |
88+
| [`LearnAPI.transform_output_scitype`](@ref)`(algorithm)` | `Any` | upper bound on `scitype(first(transform(algorithm, ...)))` | `Table(Continuous)` |
89+
| [`LearnAPI.transform_input_type`](@ref)`(algorithm)` | `Union{}` | upper bound on `typeof(data)` in `transform(algorithm, fitted_params, data...)`² | `AbstractMatrix{<:Real}}` |
90+
| [`LearnAPI.transform_output_type`](@ref)`(algorithm)` | `Any` | upper bound on `typeof(first(transform(algorithm, ...)))` | `AbstractMatrix{<:Real}` |
91+
92+
¹ If the value is `0`, then the variable in boldface type is not supported and not
93+
expected to appear in `data`. If `length(data)` is less than the trait value, then `data`
94+
is understood to exclude the variable, but note that `fit` can have multiple signatures of
95+
varying lengths, as in `fit(algorithm, verbosity, X, y)` and `fit(algorithm, verbosity, X, y,
96+
w)`. A non-zero value is a promise that `fit` includes a signature of sufficient length to
97+
include the variable.
98+
99+
² Assuming no [optional data interface](@ref data_interface) is implemented. See docstring
100+
for the general case.
101+
102+
103+
## Derived Traits
104+
105+
The following convenience methods are provided but intended for overloading:
106+
107+
| trait | return value | example |
108+
|:-------------------------------------|:------------------------------------------|:-----------|
109+
| `LearnAPI.name(algorithm)` | algorithm type name as string | "PCA" |
110+
| `LearnAPI.is_algorithm(algorithm)` | `true` if `functions(algorithm)` is not empty | `true` |
111+
| [`LearnAPI.predict_output_scitype`](@ref)(algorithm) | dictionary of upper bounds on the scitype of predictions, keyed on subtypes of [`LearnAPI.KindOfProxy`](@ref) |
112+
| [`LearnAPI.predict_output_type`](@ref)(algorithm) | dictionary of upper bounds on the type of predictions, keyed on subtypes of [`LearnAPI.KindOfProxy`](@ref) |
113+
114+
115+
## Reference
116+
117+
```@docs
118+
LearnAPI.functions
119+
LearnAPI.preferred_kind_of_proxy
120+
LearnAPI.position_of_target
121+
LearnAPI.position_of_weights
122+
LearnAPI.descriptors
123+
LearnAPI.is_pure_julia
124+
LearnAPI.pkg_name
125+
LearnAPI.pkg_license
126+
LearnAPI.doc_url
127+
LearnAPI.load_path
128+
LearnAPI.is_wrapper
129+
LearnAPI.fit_keywords
130+
LearnAPI.human_name
131+
LearnAPI.iteration_parameter
132+
LearnAPI.fit_scitype
133+
LearnAPI.fit_type
134+
LearnAPI.fit_observation_scitype
135+
LearnAPI.fit_observation_type
136+
LearnAPI.predict_input_scitype
137+
LearnAPI.predict_output_scitype
138+
LearnAPI.predict_input_type
139+
LearnAPI.predict_output_type
140+
LearnAPI.transform_input_scitype
141+
LearnAPI.transform_output_scitype
142+
LearnAPI.transform_input_type
143+
LearnAPI.transform_output_type
144+
```

0 commit comments

Comments
 (0)