Skip to content

Commit eed1f24

Browse files
committed
add some tests
1 parent 6e70c75 commit eed1f24

File tree

7 files changed

+33
-7
lines changed

7 files changed

+33
-7
lines changed

src/LearnDataFrontEnds.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
Module providing the following commonly applicable data front ends for implementations of
55
the [LearnAPI.jl](https://juliaai.github.io/LearnAPI.jl/dev/) interface:
66
7-
- [`Saffron`](@ref): good for most supervised regressors operating on structured data
7+
- [`Saffron`](@ref): good for most supervised leaners, typically regressors, operating on
8+
structured data
89
910
- [`Sage`](@ref): good for most supervised classifiers operating on structured data
1011

src/backends.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ Additionally, when `observations = fit(learner, data)` and the
2222
been implemented, one has:
2323
2424
- `observations.target`: length `n` target vector (`multitarget=false`) or size `(q, n)`
25-
target matrix (`multivariate=true`)
25+
target matrix (`multivariate=true`); this array has the same element type as the
26+
user-provided one in the `Saffron` case
2627
2728
# Specific to `Sage`
2829
29-
If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, the above
30-
target representation has an integer element type controlled by `code_type`, and we
30+
If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, then
31+
`observations.target` has an integer element type controlled by `code_type`, and we
3132
additionally have:
3233
3334
- `observations.classes`: A categorical vector of the ordered target classes, as actually

src/constants.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@ const DOC_FORMULAS = """
1111
const ERR_BAD_LEVELS = ArgumentError(
1212
"`levels` must be one of these: `:raw`, `:int`, `:small`. "
1313
)
14+
15+
const ERR_EXPECTED_CATEGORICAL = ArgumentError(
16+
"Targets (or target columns of a table) must be `CategoricalArray`, "*
17+
"or subarrays thereof. "
18+
)

src/saffron.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""
22
Saffron(; multitarget=false, view=false)
33
4-
A LearnAPI.jl data front end implemented for some supervised regressors consuming
5-
structured data. If `learner` implements this front end, then `data` in the call
6-
[`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following forms:
4+
A LearnAPI.jl data front end implemented for some supervised learners, typically
5+
regressors, consuming structured data. If `learner` implements this front end, then `data`
6+
in the call [`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following
7+
forms:
78
89
- `(X, y)`, where `X` is a feature matrix or table and `y` is a target vector, matrix or
910
table
@@ -145,6 +146,10 @@ finalize(x, names, y, ::Saffron{<:Any,<:Any,IntCode}) =
145146
finalize(x, names, y, ::Saffron{<:Any,<:Any,SmallIntCode}) =
146147
finalize(x, names, y, CategoricalArrays.refcode)
147148
function finalize(x, names, y, int) # here `int` is `levelcode` or `refcode` function
149+
y isa Union{
150+
CategoricalArrays.CategoricalArray,
151+
SubArray{<:Any, <:Any, <:CategoricalArrays.CategoricalArray},
152+
} || throw(ERR_EXPECTED_CATEGORICAL)
148153
l = LearnDataFrontEnds.classes(y)
149154
u = unique(y)
150155
mask = map(in(u), l)

src/sage.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ training.
7373
7474
# Implementation
7575
76+
If a core algorithm is happy to work with a `CategoricalArray` target, without
77+
integer-encoding it, consider using the [`Saffron`](@ref) frontend instead.
78+
7679
For learners of type `MyLearner`, with `LearnAPI.fit(::MyLearner, data)` returning
7780
objects of type `MyModel`, implement the `Sage` data front
7881
by making these declarations:

test/saffron.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ import StatsModels.@formula
1010

1111
# include("_some_learners.jl")
1212

13+
@testset "bad `Saffron` constructor argument" begin
14+
@test_throws LearnDataFrontEnds.ERR_BAD_LEVELS Saffron(code_type=:junk)
15+
end
16+
1317
n = 20
1418
rng = StableRNG(345)
1519
c, t, a, t2 = randn(rng, n), rand(rng, n), rand(rng, n), rand(rng, n)

test/sage.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ end
171171
@test CA.levels(yy) == CA.levels(y)
172172
end
173173

174+
@testset "informative error if fitobs supplied non-categorical target" begin
175+
@test_throws(
176+
LearnDataFrontEnds.ERR_EXPECTED_CATEGORICAL,
177+
fitobs("learner", ((; x=fill(1, 3)), [1 ,2, 3]), Sage())
178+
)
179+
end
180+
174181
# from test/_some_learners.jl:
175182
learner = LearnerReportingNames()
176183
model = fit(learner, X)

0 commit comments

Comments
 (0)