Commit a683c93

committed: add an observation-updatable density estimator to tests
1 parent 168e0c6

File tree

9 files changed: +185 −33 lines

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@ julia = "1.6"
 
 [extras]
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -23,6 +24,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 [targets]
 test = [
     "DataFrames",
+    "Distributions",
     "LinearAlgebra",
     "MLUtils",
     "Random",

docs/src/common_implementation_patterns.md

Lines changed: 3 additions & 5 deletions
@@ -1,8 +1,6 @@
 # Common Implementation Patterns
 
-```@raw html
-🚧
-```
+!!! warning
 
 This section is only an implementation guide. The definitive specification of the
 Learn API is given in [Reference](@ref reference).
@@ -25,7 +23,7 @@ implementations fall into one (or more) of the following informally understood p
 
 - [Iterative Algorithms](@ref)
 
-- Incremental Algorithms
+- [Incremental Algorithms](@ref): Algorithms that can be updated with new observations.
 
 - [Feature Engineering](@ref): Algorithms for selecting or combining features
 
@@ -48,7 +46,7 @@ implementations fall into one (or more) of the following informally understood p
 
 - Survival Analysis
 
-- Density Estimation: Algorithms that learn a probability distribution
+- [Density Estimation](@ref): Algorithms that learn a probability distribution
 
 - Bayesian Algorithms
 

Lines changed: 4 additions & 0 deletions
@@ -1 +1,5 @@
 # Density Estimation
+
+See these examples from tests:
+
+- [normal distribution estimator](https://github.com/JuliaAI/LearnAPI.jl/blob/dev/test/patterns/incremental_algorithms.jl)

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# Incremental Algorithms
+
+See these examples from tests:
+
+- [normal distribution estimator](https://github.com/JuliaAI/LearnAPI.jl/blob/dev/test/patterns/incremental_algorithms.jl)

src/predict_transform.jl

Lines changed: 4 additions & 0 deletions
@@ -66,6 +66,9 @@ which lists all supported target proxies.
 
 The argument `model` is anything returned by a call of the form `fit(algorithm, ...)`.
 
+If `LearnAPI.features(LearnAPI.algorithm(model)) == nothing`, then argument `data` is
+omitted. An example is density estimators.
+
 # Example
 
 In the following, `algorithm` is some supervised learning algorithm with
@@ -105,6 +108,7 @@ $(DOC_DATA_INTERFACE(:predict))
 
 """
 predict(model, data) = predict(model, kinds_of_proxy(algorithm(model)) |> first, data)
+predict(model) = predict(model, kinds_of_proxy(algorithm(model)) |> first)
 
 # automatic slurping of multiple data arguments:
 predict(model, k::KindOfProxy, data1, data2, datas...; kwargs...) =
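The data-less `predict(model)` fallback added above dispatches to the first kind of proxy the algorithm advertises. A hypothetical Python analogue (all names invented; not LearnAPI code) of this convention, for an algorithm that consumes no features:

```python
from statistics import fmean, pstdev

class NormalModel:
    """Fitted state of a toy normal-density estimator; consumes no features."""
    kinds_of_proxy = ("distribution", "point")  # first entry is the default
    def __init__(self, y):
        self.mu = fmean(y)
        self.sigma = pstdev(y)  # population (maximum-likelihood) std dev

def predict(model, kind=None):
    # With no `kind` given, fall back to the first advertised proxy kind,
    # mirroring the new data-less `predict(model)` method in the diff above.
    if kind is None:
        kind = model.kinds_of_proxy[0]
    if kind == "distribution":
        return (model.mu, model.sigma)  # parameters of the learned normal
    if kind == "point":
        return model.mu
    raise ValueError(f"unsupported kind of proxy: {kind}")

model = NormalModel([1.0, 2.0, 3.0])
assert predict(model) == predict(model, "distribution")
```

Since such an algorithm has no input features, there is no `data` argument to pass at prediction time; the fitted state alone determines the output.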

src/types.jl

Lines changed: 27 additions & 27 deletions
@@ -22,27 +22,27 @@ See also [`LearnAPI.KindOfProxy`](@ref).
 
 | type | form of an observation |
 |:----:|:-----------------------|
-| `LearnAPI.Point` | same as target observations; may have the interpretation of a 50% quantile, 50% expectile or mode |
-| `LearnAPI.Sampleable` | object that can be sampled to obtain object of the same form as target observation |
-| `LearnAPI.Distribution` | explicit probability density/mass function whose sample space is all possible target observations |
-| `LearnAPI.LogDistribution` | explicit log-probability density/mass function whose sample space is possible target observations |
-| `LearnAPI.Probability`¹ | numerical probability or probability vector |
-| `LearnAPI.LogProbability`¹ | log-probability or log-probability vector |
-| `LearnAPI.Parametric`¹ | a list of parameters (e.g., mean and variance) describing some distribution |
-| `LearnAPI.LabelAmbiguous` | collections of labels (in case of multi-class target) but without a known correspondence to the original target labels (and of possibly different number) as in, e.g., clustering |
-| `LearnAPI.LabelAmbiguousSampleable` | sampleable version of `LabelAmbiguous`; see `Sampleable` above |
-| `LearnAPI.LabelAmbiguousDistribution` | pdf/pmf version of `LabelAmbiguous`; see `Distribution` above |
-| `LearnAPI.LabelAmbiguousFuzzy` | same as `LabelAmbiguous` but with multiple values of indeterminant number |
-| `LearnAPI.Quantile`² | same as target but with quantile interpretation |
-| `LearnAPI.Expectile`² | same as target but with expectile interpretation |
-| `LearnAPI.ConfidenceInterval`² | confidence interval |
-| `LearnAPI.Fuzzy` | finite but possibly varying number of target observations |
-| `LearnAPI.ProbabilisticFuzzy` | as for `Fuzzy` but labeled with probabilities (not necessarily summing to one) |
-| `LearnAPI.SurvivalFunction` | survival function |
-| `LearnAPI.SurvivalDistribution` | probability distribution for survival time |
-| `LearnAPI.SurvivalHazardFunction` | hazard function for survival time |
-| `LearnAPI.OutlierScore` | numerical score reflecting degree of outlierness (not necessarily normalized) |
-| `LearnAPI.Continuous` | real-valued approximation/interpolation of a discrete-valued target, such as a count (e.g., number of phone calls) |
+| `Point` | same as target observations; may have the interpretation of a 50% quantile, 50% expectile or mode |
+| `Sampleable` | object that can be sampled to obtain object of the same form as target observation |
+| `Distribution` | explicit probability density/mass function whose sample space is all possible target observations |
+| `LogDistribution` | explicit log-probability density/mass function whose sample space is possible target observations |
+| `Probability`¹ | numerical probability or probability vector |
+| `LogProbability`¹ | log-probability or log-probability vector |
+| `Parametric`¹ | a list of parameters (e.g., mean and variance) describing some distribution |
+| `LabelAmbiguous` | collections of labels (in case of multi-class target) but without a known correspondence to the original target labels (and of possibly different number) as in, e.g., clustering |
+| `LabelAmbiguousSampleable` | sampleable version of `LabelAmbiguous`; see `Sampleable` above |
+| `LabelAmbiguousDistribution` | pdf/pmf version of `LabelAmbiguous`; see `Distribution` above |
+| `LabelAmbiguousFuzzy` | same as `LabelAmbiguous` but with multiple values of indeterminant number |
+| `Quantile`² | same as target but with quantile interpretation |
+| `Expectile`² | same as target but with expectile interpretation |
+| `ConfidenceInterval`² | confidence interval |
+| `Fuzzy` | finite but possibly varying number of target observations |
+| `ProbabilisticFuzzy` | as for `Fuzzy` but labeled with probabilities (not necessarily summing to one) |
+| `SurvivalFunction` | survival function |
+| `SurvivalDistribution` | probability distribution for survival time |
+| `SurvivalHazardFunction` | hazard function for survival time |
+| `OutlierScore` | numerical score reflecting degree of outlierness (not necessarily normalized) |
+| `Continuous` | real-valued approximation/interpolation of a discrete-valued target, such as a count (e.g., number of phone calls) |
 
 ¹Provided for completeness but discouraged to avoid [ambiguities in
 representation](https://github.com/alan-turing-institute/MLJ.jl/blob/dev/paper/paper.md#a-unified-approach-to-probabilistic-predictions-and-their-evaluation).
@@ -86,9 +86,9 @@ space ``Y^n``, where ``Y`` is the space from which the target variable takes its
 
 | type `T` | form of output of `predict(model, ::T, data)` |
 |:--------:|:----------------------------------------------|
-| `LearnAPI.JointSampleable` | object that can be sampled to obtain a *vector* whose elements have the form of target observations; the vector length matches the number of observations in `data`. |
-| `LearnAPI.JointDistribution` | explicit probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` |
-| `LearnAPI.JointLogDistribution` | explicit log-probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` |
+| `JointSampleable` | object that can be sampled to obtain a *vector* whose elements have the form of target observations; the vector length matches the number of observations in `data`. |
+| `JointDistribution` | explicit probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` |
+| `JointLogDistribution` | explicit log-probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` |
 
 """
 abstract type Joint <: KindOfProxy end
@@ -108,9 +108,9 @@ single object representing a probability distribution.
 
 | type `T` | form of output of `predict(model, ::T)` |
 |:--------:|:-----------------------------------------|
-| `LearnAPI.SingleSampleable` | object that can be sampled to obtain a single target observation |
-| `LearnAPI.SingleDistribution` | explicit probability density/mass function for sampling the target |
-| `LearnAPI.SingleLogDistribution` | explicit log-probability density/mass function for sampling the target |
+| `SingleSampleable` | object that can be sampled to obtain a single target observation |
+| `SingleDistribution` | explicit probability density/mass function for sampling the target |
+| `SingleLogDistribution` | explicit log-probability density/mass function for sampling the target |
 
 """
 abstract type Single <: KindOfProxy end
test/patterns/incremental_algorithms.jl

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+using LearnAPI
+using Statistics
+using StableRNGs
+
+import Distributions
+
+# # NORMAL DENSITY ESTIMATOR
+
+# An example of density estimation and also of incremental learning
+# (`update_observations`).
+
+
+# ## Implementation
+
+"""
+    NormalEstimator()
+
+Instantiate an algorithm for finding the maximum likelihood normal distribution fitting
+some real univariate data `y`. Estimates can be updated with new data.
+
+```julia
+model = fit(NormalEstimator(), y)
+d = predict(model) # returns the learned `Normal` distribution
+```
+
+While the above is equivalent to the single operation `d =
+predict(NormalEstimator(), y)`, the above workflow allows for the presentation of
+additional observations post facto: The following is equivalent to `d2 =
+predict(NormalEstimator(), vcat(y, ynew))`:
+
+```julia
+update_observations(model, ynew)
+d2 = predict(model)
+```
+
+Inspect all learned parameters with `LearnAPI.extras(model)`. Predict a 95%
+confidence interval with `predict(model, ConfidenceInterval())`.
+
+"""
+struct NormalEstimator end
+
+struct NormalEstimatorFitted{T}
+    Σy::T
+    ȳ::T
+    ss::T # sum of squared residuals
+    n::Int
+end
+
+LearnAPI.algorithm(::NormalEstimatorFitted) = NormalEstimator()
+
+function LearnAPI.fit(::NormalEstimator, y)
+    n = length(y)
+    Σy = sum(y)
+    ȳ = Σy/n
+    ss = sum(x->x^2, y) - n*ȳ^2
+    return NormalEstimatorFitted(Σy, ȳ, ss, n)
+end
+
+function LearnAPI.update_observations(model::NormalEstimatorFitted, ynew)
+    m = length(ynew)
+    n = model.n + m
+    Σynew = sum(ynew)
+    Σy = model.Σy + Σynew
+    ȳ = Σy/n
+    δ = model.n*((m*model.ȳ - Σynew)/n)^2
+    ss = model.ss + δ + sum(x -> (x - ȳ)^2, ynew)
+    return NormalEstimatorFitted(Σy, ȳ, ss, n)
+end
+
+LearnAPI.features(::NormalEstimator, y) = nothing
+LearnAPI.target(::NormalEstimator, y) = y
+
+LearnAPI.predict(model::NormalEstimatorFitted, ::Distribution) =
+    Distributions.Normal(model.ȳ, sqrt(model.ss/model.n))
+LearnAPI.predict(model::NormalEstimatorFitted, ::Point) = model.ȳ
+function LearnAPI.predict(model::NormalEstimatorFitted, ::ConfidenceInterval)
+    d = predict(model, Distribution())
+    return (quantile(d, 0.025), quantile(d, 0.975))
+end
+
+# for fit and predict in one line:
+LearnAPI.predict(::NormalEstimator, k::LearnAPI.KindOfProxy, y) =
+    predict(fit(NormalEstimator(), y), k)
+LearnAPI.predict(::NormalEstimator, y) = predict(NormalEstimator(), Distribution(), y)
+
+LearnAPI.extras(model::NormalEstimatorFitted) = (μ=model.ȳ, σ=sqrt(model.ss/model.n))
+
+@trait(
+    NormalEstimator,
+    constructor = NormalEstimator,
+    kinds_of_proxy = (Distribution(), Point(), ConfidenceInterval()),
+    tags = ("density estimation", "incremental algorithms"),
+    is_pure_julia = true,
+    human_name = "normal distribution estimator",
+    functions = (
+        :(LearnAPI.fit),
+        :(LearnAPI.algorithm),
+        :(LearnAPI.strip),
+        :(LearnAPI.obs),
+        :(LearnAPI.features),
+        :(LearnAPI.target),
+        :(LearnAPI.predict),
+        :(LearnAPI.update_observations),
+        :(LearnAPI.extras),
+    ),
+)
+
+# ## Tests
+
+@testset "NormalEstimator" begin
+    rng = StableRNG(123)
+    y = rand(rng, 50);
+    ynew = rand(rng, 10);
+    algorithm = NormalEstimator()
+    model = fit(algorithm, y)
+    d = predict(model)
+    μ, σ = Distributions.params(d)
+    @test μ ≈ mean(y)
+    @test σ ≈ std(y)*sqrt(49/50) # `std` uses Bessel's correction
+
+    # accessor function:
+    @test LearnAPI.extras(model) == (; μ, σ)
+
+    # one-liner:
+    @test predict(algorithm, y) == d
+    @test predict(algorithm, Point(), y) ≈ μ
+    @test predict(algorithm, ConfidenceInterval(), y)[1] ≈ quantile(d, 0.025)
+
+    # updating:
+    model = update_observations(model, ynew)
+    μ2, σ2 = LearnAPI.extras(model)
+    μ3, σ3 = LearnAPI.extras(fit(algorithm, vcat(y, ynew))) # training ab initio
+    @test μ2 ≈ μ3
+    @test σ2 ≈ σ3
+end
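The `update_observations` method above merges sufficient statistics rather than refitting: the stored sum of squared residuals is shifted to the new mean by δ = n₀·(ȳ₀ − ȳ)². A quick Python check of the same algebra (not part of the commit; the helper names are mine), comparing the incremental merge against a batch fit on the concatenated data:

```python
import random

def fit(y):
    # Batch fit: sufficient statistics (Σy, ȳ, ss, n), with
    # ss = Σ(x − ȳ)² computed as Σx² − n·ȳ².
    n = len(y)
    sy = sum(y)
    ybar = sy / n
    ss = sum(x * x for x in y) - n * ybar * ybar
    return sy, ybar, ss, n

def update(state, ynew):
    # Incremental merge, mirroring `update_observations` in the diff above.
    sy0, ybar0, ss0, n0 = state
    m = len(ynew)
    n = n0 + m
    synew = sum(ynew)
    sy = sy0 + synew
    ybar = sy / n
    delta = n0 * ((m * ybar0 - synew) / n) ** 2  # equals n0*(ybar0 - ybar)**2
    ss = ss0 + delta + sum((x - ybar) ** 2 for x in ynew)
    return sy, ybar, ss, n

random.seed(123)
y = [random.random() for _ in range(50)]
ynew = [random.random() for _ in range(10)]
incremental = update(fit(y), ynew)
batch = fit(y + ynew)
assert all(abs(a - b) < 1e-9 for a, b in zip(incremental, batch))
```

The identity used is Σ_old(x − ȳ)² = ss₀ + n₀(ȳ₀ − ȳ)², with ȳ₀ − ȳ = (m·ȳ₀ − Σy_new)/n, so the δ term is exactly the correction for re-centering the old residuals.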

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ test_files = [
     "patterns/regression.jl",
     "patterns/static_algorithms.jl",
     "patterns/ensembling.jl",
+    "patterns/incremental_algorithms.jl",
 ]
 
 files = isempty(ARGS) ? test_files : ARGS

test/traits.jl

Lines changed: 4 additions & 1 deletion
@@ -13,6 +13,9 @@ LearnAPI.algorithm(model::SmallAlgorithm) = model
 functions = (
     :(LearnAPI.fit),
     :(LearnAPI.algorithm),
+    :(LearnAPI.strip),
+    :(LearnAPI.obs),
+    :(LearnAPI.features),
 ),
 )
 ######## END OF IMPLEMENTATION ##################
@@ -27,7 +30,7 @@ LearnAPI.algorithm(model::SmallAlgorithm) = model
 
 small = SmallAlgorithm()
 @test LearnAPI.constructor(small) == SmallAlgorithm
-@test LearnAPI.functions(small) == (:(LearnAPI.fit), :(LearnAPI.algorithm))
+@test :(LearnAPI.algorithm) in LearnAPI.functions(small)
 @test isempty(LearnAPI.kinds_of_proxy(small))
 @test isempty(LearnAPI.tags(small))
 @test !LearnAPI.is_pure_julia(small)
