|
1 | 1 | """
|
2 | 2 | Saffron(; multitarget=false, view=false)
|
3 | 3 |
|
4 |
| -A LearnAPI.jl data front end implemented for some supervised regressors consuming |
5 |
| -structured data. If `learner` implements this front end, then `data` in the call |
6 |
| -[`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following forms: |
| 4 | +A LearnAPI.jl data front end implemented for some supervised learners, typically |
| 5 | +regressors, consuming structured data. If `learner` implements this front end, then `data` |
| 6 | +in the call [`LearnAPI.fit`](@extref)`(learner, data)` can take any of the following |
| 7 | +forms: |
7 | 8 |
|
8 | 9 | - `(X, y)`, where `X` is a feature matrix or table and `y` is a target vector, matrix or
|
9 | 10 | table
|
@@ -63,11 +64,11 @@ by making these declarations:
|
63 | 64 |
|
64 | 65 | ```julia
|
65 | 66 | using LearnDataFrontEnds
|
66 |
| -frontend = Saffron() # optionally specify `view=true` and/or `multitarget=true` |
| 67 | +const frontend = Saffron() # optionally specify `view=true` and/or `multitarget=true` |
67 | 68 |
|
68 | 69 | # both `obs` methods return objects of abstract type `Obs`:
|
69 | 70 | LearnAPI.obs(learner::MyLearner, data) = fitobs(learner, data, frontend)
|
70 |
| -LearnAPI.obs(model::MyModel, X) = obs(model, data, frontend) |
| 71 | +LearnAPI.obs(model::MyModel, data) = obs(model, data, frontend) |
71 | 72 |
|
72 | 73 | # training data deconstructors:
|
73 | 74 | LearnAPI.features(learner::MyLearner, data) = LearnAPI.features(learner, data, frontend)
|
@@ -145,6 +146,10 @@ finalize(x, names, y, ::Saffron{<:Any,<:Any,IntCode}) =
|
145 | 146 | finalize(x, names, y, ::Saffron{<:Any,<:Any,SmallIntCode}) =
|
146 | 147 | finalize(x, names, y, CategoricalArrays.refcode)
|
147 | 148 | function finalize(x, names, y, int) # here `int` is `levelcode` or `refcode` function
|
| 149 | + y isa Union{ |
| 150 | + CategoricalArrays.CategoricalArray, |
| 151 | + SubArray{<:Any, <:Any, <:CategoricalArrays.CategoricalArray}, |
| 152 | + } || throw(ERR_EXPECTED_CATEGORICAL) |
148 | 153 | l = LearnDataFrontEnds.classes(y)
|
149 | 154 | u = unique(y)
|
150 | 155 | mask = map(in(u), l)
|
|
0 commit comments