Skip to content

Commit 6b61fde

Browse files
authored
Merge pull request #28 from JuliaAI/test-classifier-predictions
Test classifier predictions for doubly wrapped `CategoricalValue`
2 parents e4fec71 + 14f3193 commit 6b61fde

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
matrix:
1919
version:
2020
- '1.6'
21+
- '1.10'
2122
- '1' # automatically expands to the latest stable 1.x release of Julia.
2223
os:
2324
- ubuntu-latest

src/MLJTestInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const N_MODELS_FOR_REPEATABILITY_TEST = 20
55
using MLJBase
66
using Pkg
77
using Test
8+
import MLJBase.CategoricalArrays.unwrap
89

910
include("attemptors.jl")
1011
include("test.jl")

src/attemptors.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,21 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
140140
methods = MLJBase.implemented_methods(fitted_machine.model)
141141
_, test = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.01)
142142
if :predict in methods
143-
predict(fitted_machine, first(data))
143+
yhat = predict(fitted_machine, first(data))
144144
model isa Static || predict(fitted_machine, rows=test)
145145
model isa Static || predict(fitted_machine, rows=:)
146146
push!(operations, "predict")
147+
148+
# check for double wrapped CategoricalValues in predict output for
149+
# classifiers:
150+
if target_scitype(model) <: AbstractVector{<:Finite} &&
151+
model isa Union{Deterministic,Probabilistic}
152+
η = model isa Deterministic ? first(yhat) : rand(first(yhat))
153+
unwrap(η) isa MLJBase.CategoricalArrays.CategoricalValue &&
154+
error("Doubly wrapped CategoricalValue encountered. Check use of "*
155+
"CategoricalArrays methods `levels` and `unique`, which changed in "*
156+
"version 1.0. ")
157+
end
147158
end
148159
if :transform in methods
149160
W = if model isa Static

0 commit comments

Comments
 (0)