14 changes: 7 additions & 7 deletions docs/src/quick_start.md
@@ -16,11 +16,11 @@ returning objects of type `MyModel`, make these declarations:

```julia
using LearnDataFrontEnds
-frontend = Saffron()
+const frontend = Saffron()

# both methods below return objects with abstract type `Obs`:
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
-LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
+LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)

# training data deconstructors:
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
@@ -72,11 +72,11 @@ returning objects of type `MyModel`, make these declarations:

```julia
using LearnDataFrontEnds
-frontend = Sage()
+const frontend = Sage()

# both methods below return objects with abstract type `Obs`:
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
-LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
+LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)

# training data deconstructors:
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
@@ -138,13 +138,13 @@ objects of type `MyModel`, make these declarations:

```julia
using LearnDataFrontEnds
-frontend = Tarragon()
+const frontend = Tarragon()

# both the following return objects with abstract type `Obs`:
-LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
+LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)

-# training data deconstructors:
+# training data deconstructor:
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
```
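
The diff does not say why `frontend` becomes `const`. One plausible motivation, stated here as an assumption, is Julia's general performance guidance: a non-constant global can be rebound to a value of any type, so functions that read it infer as `Any`. A minimal sketch using a hypothetical `DemoFrontend` type (not part of LearnDataFrontEnds):

```julia
# `DemoFrontend` is a hypothetical stand-in for a front end such as `Saffron()`.
struct DemoFrontend end

const CONST_FRONTEND = DemoFrontend()   # constant binding: its type is fixed
plain_frontend = DemoFrontend()         # untyped global: may be rebound to anything

# Zero-argument functions that simply read each global.
use_const() = CONST_FRONTEND
use_plain() = plain_frontend

# Ask the compiler what return type it infers for each function.
const_type = first(Base.return_types(use_const, Tuple{}))
plain_type = first(Base.return_types(use_plain, Tuple{}))

println("const global:     ", const_type)
println("non-const global: ", plain_type)
```

If this is indeed the reason, the change matters most for `obs` and `features` methods that are called in hot loops, since each of them closes over the global `frontend`.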

3 changes: 2 additions & 1 deletion src/LearnDataFrontEnds.jl
@@ -4,7 +4,8 @@
Module providing the following commonly applicable data front ends for implementations of
the [LearnAPI.jl](https://juliaai.github.io/LearnAPI.jl/dev/) interface:

-- [`Saffron`](@ref): good for most supervised regressors operating on structured data
+- [`Saffron`](@ref): good for most supervised leaners, typically regressors, operating on
+structured data

- [`Sage`](@ref): good for most supervised classifiers operating on structured data

7 changes: 4 additions & 3 deletions src/backends.jl
@@ -22,12 +22,13 @@ Additionally, when `observations = fit(learner, data)` and the
been implemented, one has:

- `observations.target`: length `n` target vector (`multitarget=false`) or size `(q, n)`
-target matrix (`multivariate=true`)
+target matrix (`multivariate=true`); this array has the same element type as the
+user-provided one in the `Saffron` case

# Specific to `Sage`

-If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, the above
-target representation has an integer element type controlled by `code_type`, and we
+If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, then
+`observations.target` has an integer element type controlled by `code_type`, and we
additionally have:

- `observations.classes`: A categorical vector of the ordered target classes, as actually
5 changes: 5 additions & 0 deletions src/constants.jl
@@ -11,3 +11,8 @@ const DOC_FORMULAS = """
const ERR_BAD_LEVELS = ArgumentError(
"`levels` must be one of these: `:raw`, `:int`, `:small`. "
)
+
+const ERR_EXPECTED_CATEGORICAL = ArgumentError(
+"Targets (or target columns of a table) must be `CategoricalArray`, "*
+"or subarrays thereof. "
+)
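
The new error constant pairs with a type guard that this PR adds to `finalize`. The sketch below isolates that guard so its behaviour can be tried directly. It assumes the third-party CategoricalArrays.jl package is installed; `check_categorical` and `ERR_DEMO` are hypothetical names for illustration, not part of LearnDataFrontEnds.

```julia
using CategoricalArrays  # third-party package, assumed to be installed

# Hypothetical stand-in for the package's `ERR_EXPECTED_CATEGORICAL`.
const ERR_DEMO = ArgumentError(
    "Targets must be `CategoricalArray`, or subarrays thereof. "
)

# Hypothetical helper applying the same `isa` test as the guard in `finalize`:
# accept a `CategoricalArray` or a view into one, otherwise throw.
function check_categorical(y)
    y isa Union{
        CategoricalArray,
        SubArray{<:Any, <:Any, <:CategoricalArray},
    } || throw(ERR_DEMO)
    return y
end

y = categorical(["a", "b", "a"])
check_categorical(y)             # a categorical vector passes
check_categorical(view(y, 1:2))  # so does a view of one

# A plain integer vector is rejected with the informative error.
threw = try
    check_categorical([1, 2, 3])
    false
catch err
    err isa ArgumentError
end
println("plain vector rejected: ", threw)
```

Checking the type up front means a non-categorical target fails with a message that names the requirement, rather than failing later inside a call that expects categorical data.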
15 changes: 10 additions & 5 deletions src/saffron.jl
@@ -1,9 +1,10 @@
"""
Saffron(; multitarget=false, view=false)

-A LearnAPI.jl data front end implemented for some supervised regressors consuming
-structured data. If `learner` implements this front end, then `data` in the call
-[`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following forms:
+A LearnAPI.jl data front end implemented for some supervised learners, typically
+regressors, consuming structured data. If `learner` implements this front end, then `data`
+in the call [`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following
+forms:

- `(X, y)`, where `X` is a feature matrix or table and `y` is a target vector, matrix or
table
@@ -63,11 +64,11 @@ by making these declarations:

```julia
using LearnDataFrontEnds
-frontend = Saffron() # optionally specify `view=true` and/or `multitarget=true`
+const frontend = Saffron() # optionally specify `view=true` and/or `multitarget=true`

# both `obs` methods return objects of abstract type `Obs`:
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
-LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
+LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)

# training data deconstructors:
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
@@ -145,6 +146,10 @@ finalize(x, names, y, ::Saffron{<:Any,<:Any,IntCode}) =
finalize(x, names, y, ::Saffron{<:Any,<:Any,SmallIntCode}) =
finalize(x, names, y, CategoricalArrays.refcode)
function finalize(x, names, y, int) # here `int` is `levelcode` or `refcode` function
+y isa Union{
+CategoricalArrays.CategoricalArray,
+SubArray{<:Any, <:Any, <:CategoricalArrays.CategoricalArray},
+} || throw(ERR_EXPECTED_CATEGORICAL)
l = LearnDataFrontEnds.classes(y)
u = unique(y)
mask = map(in(u), l)
7 changes: 5 additions & 2 deletions src/sage.jl
@@ -73,17 +73,20 @@ training.

# Implementation

+If a core algorithm is happy to work with a `CategoricalArray` target, without
+integer-encoding it, consider using the [`Saffron`](@ref) frontend instead.
+
For learners of type `MyLearner`, with `LearnAPI.fit(::MyLearner, data)` returning
objects of type `MyModel`, implement the `Sage` data front
by making these declarations:

```julia
using LearnDataFrontEnds
-frontend = Sage() # see above for options
+const frontend = Sage() # see above for options

# both `obs` methods return objects of abstract type `Obs`:
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
-LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
+LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)

# training data deconstructors:
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
4 changes: 2 additions & 2 deletions src/tarragon.jl
@@ -44,10 +44,10 @@ by making these declarations:

```julia
using LearnDataFrontEnds
-frontend = Tarragon() # optionally specify `view=true`
+const frontend = Tarragon() # optionally specify `view=true`

# both `obs` below return objects with abstract type `Obs`:
-LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
+LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
```
4 changes: 4 additions & 0 deletions test/saffron.jl
@@ -10,6 +10,10 @@ import StatsModels.@formula

# include("_some_learners.jl")

+@testset "bad `Saffron` constructor argument" begin
+@test_throws LearnDataFrontEnds.ERR_BAD_LEVELS Saffron(code_type=:junk)
+end
+
n = 20
rng = StableRNG(345)
c, t, a, t2 = randn(rng, n), rand(rng, n), rand(rng, n), rand(rng, n)
7 changes: 7 additions & 0 deletions test/sage.jl
@@ -171,6 +171,13 @@ end
@test CA.levels(yy) == CA.levels(y)
end

+@testset "informative error if fitobs supplied non-categorical target" begin
+@test_throws(
+LearnDataFrontEnds.ERR_EXPECTED_CATEGORICAL,
+fitobs("learner", ((; x=fill(1, 3)), [1 ,2, 3]), Sage())
+)
+end
+
# from test/_some_learners.jl:
learner = LearnerReportingNames()
model = fit(learner, X)