Skip to content

Commit e4fec71

Browse files
authored
Merge pull request #27 from JuliaAI/trait-smoke-tests
Add trait smoke tests
2 parents 44e97e3 + bcf9886 commit e4fec71

File tree

6 files changed

+72
-57
lines changed

6 files changed

+72
-57
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
with:
3030
version: ${{ matrix.version }}
3131
arch: ${{ matrix.arch }}
32-
- uses: actions/cache@v1
32+
- uses: julia-actions/cache@v1
3333
env:
3434
cache-name: cache-artifacts
3535
with:

.github/workflows/ci_nightly.yml

Lines changed: 0 additions & 50 deletions
This file was deleted.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJTestInterface"
22
uuid = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"

src/attemptors.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,51 @@ function model_type(T, mod; throw=false, verbosity=1)
6565
return model_type, outcome
6666
end
6767

68+
# helpers:
69+
ismissing_or_isa(x, T) = ismissing(x) || x isa T
70+
bad_trait(model_type) = "$model_type has a bad trait declaration.\n"
71+
72+
const err_is_pure_julia(model_type) = ErrorException(
73+
bad_trait(model_type)*"`is_pure_julia` must return `true` or `false`. "
74+
)
75+
const err_supports_weights(model_type) = ErrorException(
76+
bad_trait(model_type)*"`supports_weights` must return `true`, `false` or `missing`. "
77+
)
78+
const err_supports_class_weights(model_type) = ErrorException(
79+
bad_trait(model_type)*"`supports__class_weights` must return `true`, `false` or `missing`. "
80+
)
81+
const err_is_wrapper(model_type) = ErrorException(
82+
bad_trait(model_type)*"`is_wrapper` must return `true` or `false`. "
83+
)
84+
const err_package_name(model_type) = ErrorException(
85+
bad_trait(model_type)*"`package_name` must return a `String`. "
86+
)
87+
const err_packge_license(model_type) = ErrorException(
88+
bad_trait(model_type)*"`package_license` must return a `String`. "
89+
)
90+
const err_iteration_parameter(model_type) = ErrorException(
91+
bad_trait(model_type)*"`iteration_parameter` must return a `Symbol` or `nothing`. "
92+
)
93+
94+
function traits(model_type; throw=false, verbosity=1)
95+
message = "[:traits] Apply smoke test to some model traits"
96+
attempt(finalize(message, verbosity); throw) do
97+
ismissing_or_isa(MLJBase.is_pure_julia(model_type), Bool) ||
98+
throw(err_is_pure_julia(model_type))
99+
ismissing_or_isa(MLJBase.supports_weights(model_type), Bool) ||
100+
throw(err_supports_(model_type))
101+
ismissing_or_isa(MLJBase.supports_class_weights(model_type), Bool) ||
102+
throw(err_supports_class_weights(model_type))
103+
MLJBase.package_name(model_type) isa String ||
104+
throw(err_package_name(model_type))
105+
MLJBase.package_license(model_type) isa String ||
106+
throw(err_package_license(model_type))
107+
MLJBase.iteration_parameter(model_type) isa Union{Nothing,Symbol} ||
108+
throw(err_iteration_parameter(model_type))
109+
nothing
110+
end
111+
end
112+
68113
function model_instance(model_type; throw=false, verbosity=1)
69114
message = "[:model_instance] Instantiating default model "
70115
attempt(finalize(message, verbosity); throw) do

src/test.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ loaded into the module `mod`.
6464
6565
The extent of testing is controlled by `level`:
6666
67-
|`level` | description | tests (full list below) |
68-
|:----------------|:-------------------------------|:------------------------|
69-
| 1 | test code loading | `:model_type` |
70-
| 2 (default) | basic test of model interface | all four tests |
67+
|`level` | description | tests (full list below) |
68+
|:----------------|:-------------------------------|:-------------------------|
69+
| 1 | test code loading | `:model_type`, `:traits` |
70+
| 2 (default) | basic test of model interface | all five tests |
7171
7272
For extensive MLJ integration tests, instead use `MLJTestIntegration.test`, from
7373
MLJTestIntegration.jl.
@@ -122,6 +122,8 @@ $DOC_LIST_OF_TESTS1
122122
- `:model_type`: Check `load_path` trait is correctly overloaded by attempting to
123123
re-import the type based on that trait's value.
124124
125+
- `:traits`: Apply smoke tests to model trait values.
126+
125127
$DOC_LIST_OF_TESTS2
126128
127129
See also [`MLJTestInterface.make_binary`](@ref),
@@ -144,16 +146,18 @@ function test(model_types, data...; mod=Main, level=2, throw=false, verbosity=1,
144146
:name,
145147
:package_name,
146148
:model_type,
149+
:traits,
147150
:model_instance,
148151
:fitted_machine,
149152
:operations,
150-
), NTuple{6, String}}}(undef, nmodels)
153+
), NTuple{7, String}}}(undef, nmodels)
151154

152155
# summary table row corresponding to all tests skipped:
153156
row0 = (
154157
; name="undefined",
155158
package_name= "undefined",
156159
model_type = "-",
160+
traits = "-",
157161
model_instance = "-",
158162
fitted_machine = "-",
159163
operations = "-",
@@ -190,6 +194,15 @@ function test(model_types, data...; mod=Main, level=2, throw=false, verbosity=1,
190194
row = update!(summary, failures, row, i, :model_type, model_type, outcome)
191195
outcome == "×" && continue
192196

197+
# [traits]:
198+
traits, outcome = MLJTestInterface.traits(
199+
model_type;
200+
throw,
201+
verbosity,
202+
)
203+
row = update!(summary, failures, row, i, :traits, traits, outcome)
204+
outcome == "×" && continue
205+
193206
level > 1 || continue
194207

195208
# [model_instance]:

test/test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ expected_report1 = (
77
name = "ConstantClassifier",
88
package_name = "MLJModels",
99
model_type = "",
10+
traits = "",
1011
model_instance = "",
1112
fitted_machine = "",
1213
operations = "predict",
@@ -16,6 +17,7 @@ expected_report2 = (
1617
name = "DeterministicConstantClassifier",
1718
package_name = "MLJModels",
1819
model_type = "",
20+
traits = "",
1921
model_instance = "",
2022
fitted_machine = "",
2123
operations = "predict",
@@ -68,6 +70,7 @@ end
6870
name = "ConstantClassifier",
6971
package_name = "MLJModels",
7072
model_type = "",
73+
traits = "",
7174
model_instance = "",
7275
fitted_machine = "×",
7376
operations = "-",
@@ -77,6 +80,7 @@ end
7780
name = "DeterministicConstantClassifier",
7881
package_name = "MLJModels",
7982
model_type = "",
83+
traits = "",
8084
model_instance = "",
8185
fitted_machine = "",
8286
operations = "predict",
@@ -117,6 +121,7 @@ X, y = MLJTestInterface.make_binary()
117121
@test_logs(
118122
(:info, r"Testing ConstantClassifier"),
119123
(:info, r"model_type"),
124+
(:info, r"traits"),
120125
(:info, r"model_instance"),
121126
(:info, r"fitted_machine"),
122127
(:info, r"operations"),
@@ -145,6 +150,7 @@ end
145150
name = "ConstantClassifier",
146151
package_name = "MLJModels",
147152
model_type = "",
153+
traits = "",
148154
model_instance = "-",
149155
fitted_machine = "-",
150156
operations = "-",
@@ -164,6 +170,7 @@ end
164170
name = "ConstantClassifier",
165171
package_name = "MLJModels",
166172
model_type = "",
173+
traits = "",
167174
model_instance = "",
168175
fitted_machine = "",
169176
operations = "predict",

0 commit comments

Comments
 (0)