Skip to content

Commit e571530

Browse files
authored
Merge pull request #29 from JuliaAI/dev
For a 0.2.9 release
2 parents 7bd76d2 + 6b61fde commit e571530

File tree

8 files changed

+88
-58
lines changed

8 files changed

+88
-58
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
matrix:
1919
version:
2020
- '1.6'
21+
- '1.10'
2122
- '1' # automatically expands to the latest stable 1.x release of Julia.
2223
os:
2324
- ubuntu-latest
@@ -29,7 +30,7 @@ jobs:
2930
with:
3031
version: ${{ matrix.version }}
3132
arch: ${{ matrix.arch }}
32-
- uses: actions/cache@v1
33+
- uses: julia-actions/cache@v1
3334
env:
3435
cache-name: cache-artifacts
3536
with:

.github/workflows/ci_nightly.yml

Lines changed: 0 additions & 50 deletions
This file was deleted.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJTestInterface"
22
uuid = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Package for testing an implementation of the
55

66
[![Lifecycle:Experimental](https://img.shields.io/badge/Lifecycle-Experimental-339999)](https://github.com/bcgov/repomountie/blob/master/doc/lifecycle-badges.md) [![Build Status](https://github.com/JuliaAI/MLJTestInterface.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/MLJTestInterface.jl/actions) [![Coverage](https://codecov.io/gh/JuliaAI/MLJTestInterface.jl/branch/master/graph/badge.svg)](https://codecov.io/github/JuliaAI/MLJTestInterface.jl?branch=master)
77

8+
For more extensive testing, see [MLJTestIntegration.jl](https://github.com/JuliaAI/MLJTestIntegration.jl/tree/dev).
9+
810
# Installation
911

1012
```julia

src/MLJTestInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const N_MODELS_FOR_REPEATABILITY_TEST = 20
55
using MLJBase
66
using Pkg
77
using Test
8+
import MLJBase.CategoricalArrays.unwrap
89

910
include("attemptors.jl")
1011
include("test.jl")

src/attemptors.jl

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,51 @@ function model_type(T, mod; throw=false, verbosity=1)
6565
return model_type, outcome
6666
end
6767

68+
# helpers:
69+
ismissing_or_isa(x, T) = ismissing(x) || x isa T
70+
bad_trait(model_type) = "$model_type has a bad trait declaration.\n"
71+
72+
const err_is_pure_julia(model_type) = ErrorException(
73+
bad_trait(model_type)*"`is_pure_julia` must return `true` or `false`. "
74+
)
75+
const err_supports_weights(model_type) = ErrorException(
76+
bad_trait(model_type)*"`supports_weights` must return `true`, `false` or `missing`. "
77+
)
78+
const err_supports_class_weights(model_type) = ErrorException(
79+
bad_trait(model_type)*"`supports__class_weights` must return `true`, `false` or `missing`. "
80+
)
81+
const err_is_wrapper(model_type) = ErrorException(
82+
bad_trait(model_type)*"`is_wrapper` must return `true` or `false`. "
83+
)
84+
const err_package_name(model_type) = ErrorException(
85+
bad_trait(model_type)*"`package_name` must return a `String`. "
86+
)
87+
const err_packge_license(model_type) = ErrorException(
88+
bad_trait(model_type)*"`package_license` must return a `String`. "
89+
)
90+
const err_iteration_parameter(model_type) = ErrorException(
91+
bad_trait(model_type)*"`iteration_parameter` must return a `Symbol` or `nothing`. "
92+
)
93+
94+
function traits(model_type; throw=false, verbosity=1)
95+
message = "[:traits] Apply smoke test to some model traits"
96+
attempt(finalize(message, verbosity); throw) do
97+
ismissing_or_isa(MLJBase.is_pure_julia(model_type), Bool) ||
98+
throw(err_is_pure_julia(model_type))
99+
ismissing_or_isa(MLJBase.supports_weights(model_type), Bool) ||
100+
throw(err_supports_(model_type))
101+
ismissing_or_isa(MLJBase.supports_class_weights(model_type), Bool) ||
102+
throw(err_supports_class_weights(model_type))
103+
MLJBase.package_name(model_type) isa String ||
104+
throw(err_package_name(model_type))
105+
MLJBase.package_license(model_type) isa String ||
106+
throw(err_package_license(model_type))
107+
MLJBase.iteration_parameter(model_type) isa Union{Nothing,Symbol} ||
108+
throw(err_iteration_parameter(model_type))
109+
nothing
110+
end
111+
end
112+
68113
function model_instance(model_type; throw=false, verbosity=1)
69114
message = "[:model_instance] Instantiating default model "
70115
attempt(finalize(message, verbosity); throw) do
@@ -95,10 +140,21 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
95140
methods = MLJBase.implemented_methods(fitted_machine.model)
96141
_, test = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.01)
97142
if :predict in methods
98-
predict(fitted_machine, first(data))
143+
yhat = predict(fitted_machine, first(data))
99144
model isa Static || predict(fitted_machine, rows=test)
100145
model isa Static || predict(fitted_machine, rows=:)
101146
push!(operations, "predict")
147+
148+
# check for double wrapped CategoricalValues in predict output for
149+
# classifiers:
150+
if target_scitype(model) <: AbstractVector{<:Finite} &&
151+
model isa Union{Deterministic,Probabilistic}
152+
η = model isa Deterministic ? first(yhat) : rand(first(yhat))
153+
unwrap(η) isa MLJBase.CategoricalArrays.CategoricalValue &&
154+
error("Doubly wrapped CategoricalValue encountered. Check use of "*
155+
"CategoricalArrays methods `levels` and `unique`, which changed in "*
156+
"version 1.0. ")
157+
end
102158
end
103159
if :transform in methods
104160
W = if model isa Static

src/test.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ loaded into the module `mod`.
6464
6565
The extent of testing is controlled by `level`:
6666
67-
|`level` | description | tests (full list below) |
68-
|:----------------|:-------------------------------|:------------------------|
69-
| 1 | test code loading | `:model_type` |
70-
| 2 (default) | basic test of model interface | all four tests |
67+
|`level` | description | tests (full list below) |
68+
|:----------------|:-------------------------------|:-------------------------|
69+
| 1 | test code loading | `:model_type`, `:traits` |
70+
| 2 (default) | basic test of model interface | all five tests |
7171
7272
For extensive MLJ integration tests, instead use `MLJTestIntegration.test`, from
7373
MLJTestIntegration.jl.
@@ -122,6 +122,8 @@ $DOC_LIST_OF_TESTS1
122122
- `:model_type`: Check `load_path` trait is correctly overloaded by attempting to
123123
re-import the type based on that trait's value.
124124
125+
- `:traits`: Apply smoke tests to model trait values.
126+
125127
$DOC_LIST_OF_TESTS2
126128
127129
See also [`MLJTestInterface.make_binary`](@ref),
@@ -144,16 +146,18 @@ function test(model_types, data...; mod=Main, level=2, throw=false, verbosity=1,
144146
:name,
145147
:package_name,
146148
:model_type,
149+
:traits,
147150
:model_instance,
148151
:fitted_machine,
149152
:operations,
150-
), NTuple{6, String}}}(undef, nmodels)
153+
), NTuple{7, String}}}(undef, nmodels)
151154

152155
# summary table row corresponding to all tests skipped:
153156
row0 = (
154157
; name="undefined",
155158
package_name= "undefined",
156159
model_type = "-",
160+
traits = "-",
157161
model_instance = "-",
158162
fitted_machine = "-",
159163
operations = "-",
@@ -190,6 +194,15 @@ function test(model_types, data...; mod=Main, level=2, throw=false, verbosity=1,
190194
row = update!(summary, failures, row, i, :model_type, model_type, outcome)
191195
outcome == "×" && continue
192196

197+
# [traits]:
198+
traits, outcome = MLJTestInterface.traits(
199+
model_type;
200+
throw,
201+
verbosity,
202+
)
203+
row = update!(summary, failures, row, i, :traits, traits, outcome)
204+
outcome == "×" && continue
205+
193206
level > 1 || continue
194207

195208
# [model_instance]:

test/test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ expected_report1 = (
77
name = "ConstantClassifier",
88
package_name = "MLJModels",
99
model_type = "",
10+
traits = "",
1011
model_instance = "",
1112
fitted_machine = "",
1213
operations = "predict",
@@ -16,6 +17,7 @@ expected_report2 = (
1617
name = "DeterministicConstantClassifier",
1718
package_name = "MLJModels",
1819
model_type = "",
20+
traits = "",
1921
model_instance = "",
2022
fitted_machine = "",
2123
operations = "predict",
@@ -68,6 +70,7 @@ end
6870
name = "ConstantClassifier",
6971
package_name = "MLJModels",
7072
model_type = "",
73+
traits = "",
7174
model_instance = "",
7275
fitted_machine = "×",
7376
operations = "-",
@@ -77,6 +80,7 @@ end
7780
name = "DeterministicConstantClassifier",
7881
package_name = "MLJModels",
7982
model_type = "",
83+
traits = "",
8084
model_instance = "",
8185
fitted_machine = "",
8286
operations = "predict",
@@ -117,6 +121,7 @@ X, y = MLJTestInterface.make_binary()
117121
@test_logs(
118122
(:info, r"Testing ConstantClassifier"),
119123
(:info, r"model_type"),
124+
(:info, r"traits"),
120125
(:info, r"model_instance"),
121126
(:info, r"fitted_machine"),
122127
(:info, r"operations"),
@@ -145,6 +150,7 @@ end
145150
name = "ConstantClassifier",
146151
package_name = "MLJModels",
147152
model_type = "",
153+
traits = "",
148154
model_instance = "-",
149155
fitted_machine = "-",
150156
operations = "-",
@@ -164,6 +170,7 @@ end
164170
name = "ConstantClassifier",
165171
package_name = "MLJModels",
166172
model_type = "",
173+
traits = "",
167174
model_instance = "",
168175
fitted_machine = "",
169176
operations = "predict",

0 commit comments

Comments
 (0)