From a683c9325cfd318cbfe4d96a64fa3977d89a9af5 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Sat, 12 Oct 2024 11:06:36 +1300 Subject: [PATCH 1/3] add a observation-updatable density estimator to tests --- Project.toml | 2 + docs/src/common_implementation_patterns.md | 8 +- docs/src/patterns/density_estimation.md | 4 + docs/src/patterns/incremental_algorithms.md | 5 + src/predict_transform.jl | 4 + src/types.jl | 54 ++++---- test/patterns/incremental_algorithms.jl | 135 ++++++++++++++++++++ test/runtests.jl | 1 + test/traits.jl | 5 +- 9 files changed, 185 insertions(+), 33 deletions(-) create mode 100644 docs/src/patterns/incremental_algorithms.md create mode 100644 test/patterns/incremental_algorithms.jl diff --git a/Project.toml b/Project.toml index 849adaeb..2d23d7e2 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ julia = "1.6" [extras] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -23,6 +24,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = [ "DataFrames", + "Distributions", "LinearAlgebra", "MLUtils", "Random", diff --git a/docs/src/common_implementation_patterns.md b/docs/src/common_implementation_patterns.md index 5ab63cce..7f7641a4 100644 --- a/docs/src/common_implementation_patterns.md +++ b/docs/src/common_implementation_patterns.md @@ -1,8 +1,6 @@ # Common Implementation Patterns -```@raw html -🚧 -``` +!!! warning This section is only an implementation guide. The definitive specification of the Learn API is given in [Reference](@ref reference). @@ -25,7 +23,7 @@ implementations fall into one (or more) of the following informally understood p - [Iterative Algorithms](@ref) -- Incremental Algorithms +- [Incremental Algorithms](@ref): Algorithms that can be updated with new observations. 
- [Feature Engineering](@ref): Algorithms for selecting or combining features @@ -48,7 +46,7 @@ implementations fall into one (or more) of the following informally understood p - Survival Analysis -- Density Estimation: Algorithms that learn a probability distribution +- [Density Estimation](@ref): Algorithms that learn a probability distribution - Bayesian Algorithms diff --git a/docs/src/patterns/density_estimation.md b/docs/src/patterns/density_estimation.md index f535f9fe..e9ca083b 100644 --- a/docs/src/patterns/density_estimation.md +++ b/docs/src/patterns/density_estimation.md @@ -1 +1,5 @@ # Density Estimation + +See these examples from tests: + +- [normal distribution estimator](https://github.com/JuliaAI/LearnAPI.jl/blob/dev/test/patterns/incremental_algorithms.jl) diff --git a/docs/src/patterns/incremental_algorithms.md b/docs/src/patterns/incremental_algorithms.md new file mode 100644 index 00000000..89ad8643 --- /dev/null +++ b/docs/src/patterns/incremental_algorithms.md @@ -0,0 +1,5 @@ +# Incremental Algorithms + +See these examples from tests: + +- [normal distribution estimator](https://github.com/JuliaAI/LearnAPI.jl/blob/dev/test/patterns/incremental_algorithms.jl) diff --git a/src/predict_transform.jl b/src/predict_transform.jl index d59ac78e..39bff2a9 100644 --- a/src/predict_transform.jl +++ b/src/predict_transform.jl @@ -66,6 +66,9 @@ which lists all supported target proxies. The argument `model` is anything returned by a call of the form `fit(algorithm, ...)`. +If `LearnAPI.features(algorithm, training_data)` returns `nothing`, then argument `data` is +omitted. An example is density estimators.
+ # Example In the following, `algorithm` is some supervised learning algorithm with @@ -105,6 +108,7 @@ $(DOC_DATA_INTERFACE(:predict)) """ predict(model, data) = predict(model, kinds_of_proxy(algorithm(model)) |> first, data) +predict(model) = predict(model, kinds_of_proxy(algorithm(model)) |> first) # automatic slurping of multiple data arguments: predict(model, k::KindOfProxy, data1, data2, datas...; kwargs...) = diff --git a/src/types.jl b/src/types.jl index 25f98d81..be40922f 100644 --- a/src/types.jl +++ b/src/types.jl @@ -22,27 +22,27 @@ See also [`LearnAPI.KindOfProxy`](@ref). | type | form of an observation | |:-------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `LearnAPI.Point` | same as target observations; may have the interpretation of a 50% quantile, 50% expectile or mode | -| `LearnAPI.Sampleable` | object that can be sampled to obtain object of the same form as target observation | -| `LearnAPI.Distribution` | explicit probability density/mass function whose sample space is all possible target observations | -| `LearnAPI.LogDistribution` | explicit log-probability density/mass function whose sample space is possible target observations | -| `LearnAPI.Probability`¹ | numerical probability or probability vector | -| `LearnAPI.LogProbability`¹ | log-probability or log-probability vector | -| `LearnAPI.Parametric`¹ | a list of parameters (e.g., mean and variance) describing some distribution | -| `LearnAPI.LabelAmbiguous` | collections of labels (in case of multi-class target) but without a known correspondence to the original target labels (and of possibly different number) as in, e.g., clustering | -| `LearnAPI.LabelAmbiguousSampleable` | sampleable version of `LabelAmbiguous`; see `Sampleable` above | -| `LearnAPI.LabelAmbiguousDistribution` | pdf/pmf version of `LabelAmbiguous`; 
see `Distribution` above | -| `LearnAPI.LabelAmbiguousFuzzy` | same as `LabelAmbiguous` but with multiple values of indeterminant number | -| `LearnAPI.Quantile`² | same as target but with quantile interpretation | -| `LearnAPI.Expectile`² | same as target but with expectile interpretation | -| `LearnAPI.ConfidenceInterval`² | confidence interval | -| `LearnAPI.Fuzzy` | finite but possibly varying number of target observations | -| `LearnAPI.ProbabilisticFuzzy` | as for `Fuzzy` but labeled with probabilities (not necessarily summing to one) | -| `LearnAPI.SurvivalFunction` | survival function | -| `LearnAPI.SurvivalDistribution` | probability distribution for survival time | -| `LearnAPI.SurvivalHazardFunction` | hazard function for survival time | -| `LearnAPI.OutlierScore` | numerical score reflecting degree of outlierness (not necessarily normalized) | -| `LearnAPI.Continuous` | real-valued approximation/interpolation of a discrete-valued target, such as a count (e.g., number of phone calls) | +| `Point` | same as target observations; may have the interpretation of a 50% quantile, 50% expectile or mode | +| `Sampleable` | object that can be sampled to obtain object of the same form as target observation | +| `Distribution` | explicit probability density/mass function whose sample space is all possible target observations | +| `LogDistribution` | explicit log-probability density/mass function whose sample space is possible target observations | +| `Probability`¹ | numerical probability or probability vector | +| `LogProbability`¹ | log-probability or log-probability vector | +| `Parametric`¹ | a list of parameters (e.g., mean and variance) describing some distribution | +| `LabelAmbiguous` | collections of labels (in case of multi-class target) but without a known correspondence to the original target labels (and of possibly different number) as in, e.g., clustering | +| `LabelAmbiguousSampleable` | sampleable version of `LabelAmbiguous`; see `Sampleable` above | 
+| `LabelAmbiguousDistribution` | pdf/pmf version of `LabelAmbiguous`; see `Distribution` above | +| `LabelAmbiguousFuzzy` | same as `LabelAmbiguous` but with multiple values of indeterminant number | +| `Quantile`² | same as target but with quantile interpretation | +| `Expectile`² | same as target but with expectile interpretation | +| `ConfidenceInterval`² | confidence interval | +| `Fuzzy` | finite but possibly varying number of target observations | +| `ProbabilisticFuzzy` | as for `Fuzzy` but labeled with probabilities (not necessarily summing to one) | +| `SurvivalFunction` | survival function | +| `SurvivalDistribution` | probability distribution for survival time | +| `SurvivalHazardFunction` | hazard function for survival time | +| `OutlierScore` | numerical score reflecting degree of outlierness (not necessarily normalized) | +| `Continuous` | real-valued approximation/interpolation of a discrete-valued target, such as a count (e.g., number of phone calls) | ¹Provided for completeness but discouraged to avoid [ambiguities in representation](https://github.com/alan-turing-institute/MLJ.jl/blob/dev/paper/paper.md#a-unified-approach-to-probabilistic-predictions-and-their-evaluation). @@ -86,9 +86,9 @@ space ``Y^n``, where ``Y`` is the space from which the target variable takes its | type `T` | form of output of `predict(model, ::T, data)` | |:-------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `LearnAPI.JointSampleable` | object that can be sampled to obtain a *vector* whose elements have the form of target observations; the vector length matches the number of observations in `data`. 
| -| `LearnAPI.JointDistribution` | explicit probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` | -| `LearnAPI.JointLogDistribution` | explicit log-probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` | +| `JointSampleable` | object that can be sampled to obtain a *vector* whose elements have the form of target observations; the vector length matches the number of observations in `data`. | +| `JointDistribution` | explicit probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` | +| `JointLogDistribution` | explicit log-probability density/mass function whose sample space is vectors of target observations; the vector length matches the number of observations in `data` | """ abstract type Joint <: KindOfProxy end @@ -108,9 +108,9 @@ single object representing a probability distribution. 
| type `T` | form of output of `predict(model, ::T)` | |:--------------------------------:|:-----------------------------------------------------------------------| -| `LearnAPI.SingleSampleable` | object that can be sampled to obtain a single target observation | -| `LearnAPI.SingleDistribution` | explicit probability density/mass function for sampling the target | -| `LearnAPI.SingleLogDistribution` | explicit log-probability density/mass function for sampling the target | +| `SingleSampleable` | object that can be sampled to obtain a single target observation | +| `SingleDistribution` | explicit probability density/mass function for sampling the target | +| `SingleLogDistribution` | explicit log-probability density/mass function for sampling the target | """ abstract type Single <: KindOfProxy end diff --git a/test/patterns/incremental_algorithms.jl b/test/patterns/incremental_algorithms.jl new file mode 100644 index 00000000..71f9bb26 --- /dev/null +++ b/test/patterns/incremental_algorithms.jl @@ -0,0 +1,135 @@ +using LearnAPI +using Statistics +using StableRNGs + +import Distributions + +# # NORMAL DENSITY ESTIMATOR + +# An example of density estimation and also of incremental learning +# (`update_observations`). + + +# ## Implementation + +""" + NormalEstimator() + +Instantiate an algorithm for finding the maximum likelihood normal distribution fitting +some real univariate data `y`. Estimates can be updated with new data. + +```julia +model = fit(NormalEstimator(), y) +d = predict(model) # returns the learned `Normal` distribution +``` + +While the above is equivalent to the single operation `d = +predict(NormalEstimator(), y)`, the above workflow allows for the presentation of +additional observations post facto: The following is equivalent to `d2 = +predict(NormalEstimator(), vcat(y, ynew))`: + +```julia +model = update_observations(model, ynew) +d2 = predict(model) +``` + +Inspect all learned parameters with `LearnAPI.extras(model)`.
Predict a 95% +confidence interval with `predict(model, ConfidenceInterval())` + +""" +struct NormalEstimator end + +struct NormalEstimatorFitted{T} + Σy::T + ȳ::T + ss::T # sum of squared residuals + n::Int +end + +LearnAPI.algorithm(::NormalEstimatorFitted) = NormalEstimator() + +function LearnAPI.fit(::NormalEstimator, y) + n = length(y) + Σy = sum(y) + ȳ = Σy/n + ss = sum(x->x^2, y) - n*ȳ^2 + return NormalEstimatorFitted(Σy, ȳ, ss, n) +end + +function LearnAPI.update_observations(model::NormalEstimatorFitted, ynew) + m = length(ynew) + n = model.n + m + Σynew = sum(ynew) + Σy = model.Σy + Σynew + ȳ = Σy/n + δ = model.n*((m*model.ȳ - Σynew)/n)^2 + ss = model.ss + δ + sum(x -> (x - ȳ)^2, ynew) + return NormalEstimatorFitted(Σy, ȳ, ss, n) +end + +LearnAPI.features(::NormalEstimator, y) = nothing +LearnAPI.target(::NormalEstimator, y) = y + +LearnAPI.predict(model::NormalEstimatorFitted, ::Distribution) = + Distributions.Normal(model.ȳ, sqrt(model.ss/model.n)) +LearnAPI.predict(model::NormalEstimatorFitted, ::Point) = model.ȳ +function LearnAPI.predict(model::NormalEstimatorFitted, ::ConfidenceInterval) + d = predict(model, Distribution()) + return (quantile(d, 0.025), quantile(d, 0.975)) +end + +# for fit and predict in one line: +LearnAPI.predict(::NormalEstimator, k::LearnAPI.KindOfProxy, y) = + predict(fit(NormalEstimator(), y), k) +LearnAPI.predict(::NormalEstimator, y) = predict(NormalEstimator(), Distribution(), y) + +LearnAPI.extras(model::NormalEstimatorFitted) = (μ=model.ȳ, σ=sqrt(model.ss/model.n)) + +@trait( + NormalEstimator, + constructor = NormalEstimator, + kinds_of_proxy = (Distribution(), Point(), ConfidenceInterval()), + tags = ("density estimation", "incremental algorithms"), + is_pure_julia = true, + human_name = "normal distribution estimator", + functions = ( + :(LearnAPI.fit), + :(LearnAPI.algorithm), + :(LearnAPI.strip), + :(LearnAPI.obs), + :(LearnAPI.features), + :(LearnAPI.target), + :(LearnAPI.predict), + 
:(LearnAPI.update_observations), + :(LearnAPI.extras), + ), +) + +# ## Tests + +@testset "NormalEstimator" begin + rng = StableRNG(123) + y = rand(rng, 50); + ynew = rand(rng, 10); + algorithm = NormalEstimator() + model = fit(algorithm, y) + d = predict(model) + μ, σ = Distributions.params(d) + @test μ ≈ mean(y) + @test σ ≈ std(y)*sqrt(49/50) # `std` uses Bessel's correction + + # accessor function: + @test LearnAPI.extras(model) == (; μ, σ) + + # one-liner: + @test predict(algorithm, y) == d + @test predict(algorithm, Point(), y) ≈ μ + @test predict(algorithm, ConfidenceInterval(), y)[1] ≈ quantile(d, 0.025) + + # updating: + model = update_observations(model, ynew) + μ2, σ2 = LearnAPI.extras(model) + μ3, σ3 = LearnAPI.extras(fit(algorithm, vcat(y, ynew))) # training ab initio + @test μ2 ≈ μ3 + @test σ2 ≈ σ3 +end diff --git a/test/runtests.jl b/test/runtests.jl index 5385a731..63cdfe62 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ test_files = [ "patterns/regression.jl", "patterns/static_algorithms.jl", "patterns/ensembling.jl", + "patterns/incremental_algorithms.jl", ] files = isempty(ARGS) ? 
test_files : ARGS diff --git a/test/traits.jl b/test/traits.jl index a0c8a3d9..ab4cad1a 100644 --- a/test/traits.jl +++ b/test/traits.jl @@ -13,6 +13,9 @@ LearnAPI.algorithm(model::SmallAlgorithm) = model functions = ( :(LearnAPI.fit), :(LearnAPI.algorithm), + :(LearnAPI.strip), + :(LearnAPI.obs), + :(LearnAPI.features), ), ) ######## END OF IMPLEMENTATION ################## @@ -27,7 +30,7 @@ LearnAPI.algorithm(model::SmallAlgorithm) = model small = SmallAlgorithm() @test LearnAPI.constructor(small) == SmallAlgorithm -@test LearnAPI.functions(small) == (:(LearnAPI.fit), :(LearnAPI.algorithm)) +@test :(LearnAPI.algorithm) in LearnAPI.functions(small) @test isempty(LearnAPI.kinds_of_proxy(small)) @test isempty(LearnAPI.tags(small)) @test !LearnAPI.is_pure_julia(small) From 9b9e4d45a91954e14230b4073092208bc391170e Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Sat, 12 Oct 2024 11:13:34 +1300 Subject: [PATCH 2/3] typo --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 36e361d2..959d1dc7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -98,7 +98,7 @@ A key to enabling toolboxes to enhance LearnAPI.jl algorithm functionality is th implementation of two key additional methods, beyond the usual `fit` and `predict`/`transform`. Given any training `data` consumed by `fit` (such as `data = (X, y)` in the example above) [`LearnAPI.features(algorithm, data)`](@ref input) tells us what -part of `data` comprises *features*, which is something that can be passsed onto to +part of `data` comprises *features*, which is something that can be passed on to `predict` or `transform` (`X` in the example) while [`LearnAPI.target(algorithm, data)`](@ref), if implemented, tells us what part comprises the target (`y` in the example).
By explicitly requiring such methods, we free algorithms to consume data in From d82eaa5a4572a07120bce19e5586d9aa6ef888be Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Sat, 12 Oct 2024 12:02:15 +1300 Subject: [PATCH 3/3] add a test for predict and transform slurping fallbacks oops --- src/predict_transform.jl | 8 ++------ test/predict_transform.jl | 19 +++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 test/predict_transform.jl diff --git a/src/predict_transform.jl b/src/predict_transform.jl index 39bff2a9..1fff01f3 100644 --- a/src/predict_transform.jl +++ b/src/predict_transform.jl @@ -4,10 +4,6 @@ function DOC_IMPLEMENTED_METHODS(name; overloaded=false) "[`LearnAPI.functions`](@ref) trait. " end -const OPERATIONS = (:predict, :transform, :inverse_transform) -const DOC_OPERATIONS_LIST_SYMBOL = join(map(op -> "`:$op`", OPERATIONS), ", ") -const DOC_OPERATIONS_LIST_FUNCTION = join(map(op -> "`LearnAPI.$op`", OPERATIONS), ", ") - DOC_MUTATION(op) = """ @@ -171,8 +167,8 @@ $(DOC_MUTATION(:transform)) $(DOC_DATA_INTERFACE(:transform)) """ -transform(model, data1, data2...; kwargs...) = - transform(model, (data1, datas...); kwargs...) # automatic slurping +transform(model, data1, data2, datas...; kwargs...) = + transform(model, (data1, data2, datas...); kwargs...) 
# automatic slurping """ inverse_transform(model, data) diff --git a/test/predict_transform.jl b/test/predict_transform.jl new file mode 100644 index 00000000..7a496115 --- /dev/null +++ b/test/predict_transform.jl @@ -0,0 +1,19 @@ +using Test +using LearnAPI + +struct Goose end + +LearnAPI.fit(algorithm::Goose) = Ref(algorithm) +LearnAPI.algorithm(::Base.RefValue{Goose}) = Goose() +LearnAPI.predict(::Base.RefValue{Goose}, ::Point, data) = sum(data) +LearnAPI.transform(::Base.RefValue{Goose}, data) = prod(data) +@trait Goose kinds_of_proxy = (Point(),) + +@testset "predict and transform argument slurping" begin + model = fit(Goose()) + @test predict(model, Point(), 2, 3, 4) == 9 + @test predict(model, 2, 3, 4) == 9 + @test transform(model, 2, 3, 4) == 24 +end + +true diff --git a/test/runtests.jl b/test/runtests.jl index 63cdfe62..9af76002 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ test_files = [ "tools.jl", "traits.jl", "clone.jl", + "predict_transform.jl", "patterns/regression.jl", "patterns/static_algorithms.jl", "patterns/ensembling.jl",