diff --git a/docs/src/quick_start.md b/docs/src/quick_start.md index 078c401..6305c68 100644 --- a/docs/src/quick_start.md +++ b/docs/src/quick_start.md @@ -16,11 +16,11 @@ returning objects of type `MyModel`, make these declarations: ```julia using LearnDataFrontEnds -frontend = Saffron() +const frontend = Saffron() # both methods below return objects with abstract type `Obs`: LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend) -LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend) +LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend) # training data deconstructors: LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend) @@ -72,11 +72,11 @@ returning objects of type `MyModel`, make these declarations: ```julia using LearnDataFrontEnds -frontend = Sage() +const frontend = Sage() # both methods below return objects with abstract type `Obs`: LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend) -LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend) +LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend) # training data deconstructors: LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend) @@ -138,13 +138,13 @@ objects of type `MyModel`, make these declarations: ```julia using LearnDataFrontEnds -frontend = Tarragon() +const frontend = Tarragon() # both the following return objects with abstract type `Obs`: -LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend) LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend) +LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend) -# training data deconstructors: +# training data deconstructor: LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend) ``` diff --git a/src/LearnDataFrontEnds.jl b/src/LearnDataFrontEnds.jl index 34e589d..26a674b 100644 --- a/src/LearnDataFrontEnds.jl +++ b/src/LearnDataFrontEnds.jl @@ -4,7 +4,8 @@ Module providing the following commonly applicable data front ends for implementations of the [LearnAPI.jl](https://juliaai.github.io/LearnAPI.jl/dev/) interface: -- [`Saffron`](@ref): good for most supervised regressors operating on structured data +- [`Saffron`](@ref): good for most supervised leaners, typically regressors, operating on + structured data - [`Sage`](@ref): good for most supervised classifiers operating on structured data diff --git a/src/backends.jl b/src/backends.jl index a134fc3..2bb885b 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -22,12 +22,13 @@ Additionally, when `observations = fit(learner, data)` and the been implemented, one has: - `observations.target`: length `n` target vector (`multitarget=false`) or size `(q, n)` - target matrix (`multivariate=true`) + target matrix (`multivariate=true`); this array has the same element type as the + user-provided one in the `Saffron` case # Specific to `Sage` -If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, the above -target representation has an integer element type controlled by `code_type`, and we +If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, then +`observations.target` has an integer element type controlled by `code_type`, and we additionally have: - `observations.classes`: A categorical vector of the ordered target classes, as actually diff --git a/src/constants.jl b/src/constants.jl index a5f974e..807292f 100644 --- a/src/constants.jl +++ b/src/constants.jl @@ -11,3 +11,8 @@ const DOC_FORMULAS = """ const ERR_BAD_LEVELS = ArgumentError( "`levels` must be one of these: `:raw`, `:int`, `:small`. " ) + +const ERR_EXPECTED_CATEGORICAL = ArgumentError( + "Targets (or target columns of a table) must be `CategoricalArray`, "* + "or subarrays thereof. " +) diff --git a/src/saffron.jl b/src/saffron.jl index d7ccc61..716700f 100644 --- a/src/saffron.jl +++ b/src/saffron.jl @@ -1,9 +1,10 @@ """ Saffron(; multitarget=false, view=false) -A LearnAPI.jl data front end implemented for some supervised regressors consuming -structured data. If `learner` implements this front end, then `data` in the call -[`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following forms: +A LearnAPI.jl data front end implemented for some supervised learners, typically +regressors, consuming structured data. If `learner` implements this front end, then `data` +in the call [`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following +forms: - `(X, y)`, where `X` is a feature matrix or table and `y` is a target vector, matrix or table @@ -63,11 +64,11 @@ by making these declarations: ```julia using LearnDataFrontEnds -frontend = Saffron() # optionally specify `view=true` and/or `multitarget=true` +const frontend = Saffron() # optionally specify `view=true` and/or `multitarget=true` # both `obs` methods return objects of abstract type `Obs`: LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend) -LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend) +LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend) # training data deconstructors: LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend) @@ -145,6 +146,10 @@ finalize(x, names, y, ::Saffron{<:Any,<:Any,IntCode}) = finalize(x, names, y, ::Saffron{<:Any,<:Any,SmallIntCode}) = finalize(x, names, y, CategoricalArrays.refcode) function finalize(x, names, y, int) # here `int` is `levelcode` or `refcode` function + y isa Union{ + CategoricalArrays.CategoricalArray, + SubArray{<:Any, <:Any, <:CategoricalArrays.CategoricalArray}, + } || throw(ERR_EXPECTED_CATEGORICAL) l = LearnDataFrontEnds.classes(y) u = unique(y) mask = map(in(u), l) diff --git a/src/sage.jl b/src/sage.jl index c94b2d7..21eebf3 100644 --- a/src/sage.jl +++ b/src/sage.jl @@ -73,17 +73,20 @@ training. # Implementation +If a core algorithm is happy to work with a `CategoricalArray` target, without +integer-encoding it, consider using the [`Saffron`](@ref) frontend instead. + For learners of type `MyLearner`, with `LearnAPI.fit(::MyLearner, data)` returning objects of type `MyModel`, implement the `Sage` data front by making these declarations: ```julia using LearnDataFrontEnds -frontend = Sage() # see above for options +const frontend = Sage() # see above for options # both `obs` methods return objects of abstract type `Obs`: LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend) -LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend) +LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend) # training data deconstructors: LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend) diff --git a/src/tarragon.jl b/src/tarragon.jl index 79db1db..28ddf54 100644 --- a/src/tarragon.jl +++ b/src/tarragon.jl @@ -44,10 +44,10 @@ by making these declarations: ```julia using LearnDataFrontEnds -frontend = Tarragon() # optionally specify `view=true` +const frontend = Tarragon() # optionally specify `view=true` # both `obs` below return objects with abstract type `Obs`: -LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend) +LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend) LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend) LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend) ``` diff --git a/test/saffron.jl b/test/saffron.jl index e01877c..fc9735d 100644 --- a/test/saffron.jl +++ b/test/saffron.jl @@ -10,6 +10,10 @@ import StatsModels.@formula # include("_some_learners.jl") +@testset "bad `Saffron` constructor argument" begin + @test_throws LearnDataFrontEnds.ERR_BAD_LEVELS Saffron(code_type=:junk) +end + n = 20 rng = StableRNG(345) c, t, a, t2 = randn(rng, n), rand(rng, n), rand(rng, n), rand(rng, n) diff --git a/test/sage.jl b/test/sage.jl index 9532511..16ff785 100644 --- a/test/sage.jl +++ b/test/sage.jl @@ -171,6 +171,13 @@ end @test CA.levels(yy) == CA.levels(y) end +@testset "informative error if fitobs supplied non-categorical target" begin + @test_throws( + LearnDataFrontEnds.ERR_EXPECTED_CATEGORICAL, + fitobs("learner", ((; x=fill(1, 3)), [1 ,2, 3]), Sage()) + ) +end + # from test/_some_learners.jl: learner = LearnerReportingNames() model = fit(learner, X)