Skip to content

Commit f656064

Browse files
authored
Merge pull request #3 from JuliaAI/boost-coverage
Add some tests to boost coverage
2 parents 6e70c75 + 8a18815 commit f656064

File tree

9 files changed

+46
-20
lines changed

9 files changed

+46
-20
lines changed

docs/src/quick_start.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ returning objects of type `MyModel`, make these declarations:
1616

1717
```julia
1818
using LearnDataFrontEnds
19-
frontend = Saffron()
19+
const frontend = Saffron()
2020

2121
# both methods below return objects with abstract type `Obs`:
2222
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
23-
LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
23+
LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)
2424

2525
# training data deconstructors:
2626
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
@@ -72,11 +72,11 @@ returning objects of type `MyModel`, make these declarations:
7272

7373
```julia
7474
using LearnDataFrontEnds
75-
frontend = Sage()
75+
const frontend = Sage()
7676

7777
# both methods below return objects with abstract type `Obs`:
7878
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
79-
LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
79+
LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)
8080

8181
# training data deconstructors:
8282
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
@@ -138,13 +138,13 @@ objects of type `MyModel`, make these declarations:
138138

139139
```julia
140140
using LearnDataFrontEnds
141-
frontend = Tarragon()
141+
const frontend = Tarragon()
142142

143143
# both the following return objects with abstract type `Obs`:
144-
LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
145144
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
145+
LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
146146

147-
# training data deconstructors:
147+
# training data deconstructor:
148148
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
149149
```
150150

src/LearnDataFrontEnds.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
Module providing the following commonly applicable data front ends for implementations of
55
the [LearnAPI.jl](https://juliaai.github.io/LearnAPI.jl/dev/) interface:
66
7-
- [`Saffron`](@ref): good for most supervised regressors operating on structured data
7+
- [`Saffron`](@ref): good for most supervised leaners, typically regressors, operating on
8+
structured data
89
910
- [`Sage`](@ref): good for most supervised classifiers operating on structured data
1011

src/backends.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ Additionally, when `observations = fit(learner, data)` and the
2222
been implemented, one has:
2323
2424
- `observations.target`: length `n` target vector (`multitarget=false`) or size `(q, n)`
25-
target matrix (`multivariate=true`)
25+
target matrix (`multivariate=true`); this array has the same element type as the
26+
user-provided one in the `Saffron` case
2627
2728
# Specific to `Sage`
2829
29-
If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, the above
30-
target representation has an integer element type controlled by `code_type`, and we
30+
If [`Sage`](@ref)`(multitarget=..., code_type=...)` has been implemented, then
31+
`observations.target` has an integer element type controlled by `code_type`, and we
3132
additionally have:
3233
3334
- `observations.classes`: A categorical vector of the ordered target classes, as actually

src/constants.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@ const DOC_FORMULAS = """
1111
const ERR_BAD_LEVELS = ArgumentError(
1212
"`levels` must be one of these: `:raw`, `:int`, `:small`. "
1313
)
14+
15+
const ERR_EXPECTED_CATEGORICAL = ArgumentError(
16+
"Targets (or target columns of a table) must be `CategoricalArray`, "*
17+
"or subarrays thereof. "
18+
)

src/saffron.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""
22
Saffron(; multitarget=false, view=false)
33
4-
A LearnAPI.jl data front end implemented for some supervised regressors consuming
5-
structured data. If `learner` implements this front end, then `data` in the call
6-
[`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following forms:
4+
A LearnAPI.jl data front end implemented for some supervised learners, typically
5+
regressors, consuming structured data. If `learner` implements this front end, then `data`
6+
in the call [`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following
7+
forms:
78
89
- `(X, y)`, where `X` is a feature matrix or table and `y` is a target vector, matrix or
910
table
@@ -63,11 +64,11 @@ by making these declarations:
6364
6465
```julia
6566
using LearnDataFrontEnds
66-
frontend = Saffron() # optionally specify `view=true` and/or `multitarget=true`
67+
const frontend = Saffron() # optionally specify `view=true` and/or `multitarget=true`
6768
6869
# both `obs` methods return objects of abstract type `Obs`:
6970
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
70-
LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
71+
LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)
7172
7273
# training data deconstructors:
7374
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
@@ -145,6 +146,10 @@ finalize(x, names, y, ::Saffron{<:Any,<:Any,IntCode}) =
145146
finalize(x, names, y, ::Saffron{<:Any,<:Any,SmallIntCode}) =
146147
finalize(x, names, y, CategoricalArrays.refcode)
147148
function finalize(x, names, y, int) # here `int` is `levelcode` or `refcode` function
149+
y isa Union{
150+
CategoricalArrays.CategoricalArray,
151+
SubArray{<:Any, <:Any, <:CategoricalArrays.CategoricalArray},
152+
} || throw(ERR_EXPECTED_CATEGORICAL)
148153
l = LearnDataFrontEnds.classes(y)
149154
u = unique(y)
150155
mask = map(in(u), l)

src/sage.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,20 @@ training.
7373
7474
# Implementation
7575
76+
If a core algorithm is happy to work with a `CategoricalArray` target, without
77+
integer-encoding it, consider using the [`Saffron`](@ref) frontend instead.
78+
7679
For learners of type `MyLearner`, with `LearnAPI.fit(::MyLearner, data)` returning
7780
objects of type `MyModel`, implement the `Sage` data front
7881
by making these declarations:
7982
8083
```julia
8184
using LearnDataFrontEnds
82-
frontend = Sage() # see above for options
85+
const frontend = Sage() # see above for options
8386
8487
# both `obs` methods return objects of abstract type `Obs`:
8588
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
86-
LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
89+
LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)
8790
8891
# training data deconstructors:
8992
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)

src/tarragon.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ by making these declarations:
4444
4545
```julia
4646
using LearnDataFrontEnds
47-
frontend = Tarragon() # optionally specify `view=true`
47+
const frontend = Tarragon() # optionally specify `view=true`
4848
4949
# both `obs` below return objects with abstract type `Obs`:
50-
LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend)
50+
LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend)
5151
LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
5252
LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
5353
```

test/saffron.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ import StatsModels.@formula
1010

1111
# include("_some_learners.jl")
1212

13+
@testset "bad `Saffron` constructor argument" begin
14+
@test_throws LearnDataFrontEnds.ERR_BAD_LEVELS Saffron(code_type=:junk)
15+
end
16+
1317
n = 20
1418
rng = StableRNG(345)
1519
c, t, a, t2 = randn(rng, n), rand(rng, n), rand(rng, n), rand(rng, n)

test/sage.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ end
171171
@test CA.levels(yy) == CA.levels(y)
172172
end
173173

174+
@testset "informative error if fitobs supplied non-categorical target" begin
175+
@test_throws(
176+
LearnDataFrontEnds.ERR_EXPECTED_CATEGORICAL,
177+
fitobs("learner", ((; x=fill(1, 3)), [1 ,2, 3]), Sage())
178+
)
179+
end
180+
174181
# from test/_some_learners.jl:
175182
learner = LearnerReportingNames()
176183
model = fit(learner, X)

0 commit comments

Comments
 (0)