
Commit 5886633

ablaom and OkonSamuel authored
For a 0.5.1 release (#156)
* add `feature_importances` stub (#148)
* add intrinsic_importances stub and set fallback to nothing.
* fix error in intrinsic_importances docstring.
* rename intrinsic_importances method to feature_importances.
* remove fallback for feature_importances.
* Update src/model_api.jl
  Co-authored-by: Anthony Blaom, PhD <[email protected]>
* bump 1.5
* Update `metadata_model` to include traits for feature importances and training losses (#155)
* + supports_training_losses, reports_feature_importances to model_metadata
* bump 1.5.1
* bump StatisticalTraits = "3.1"

Co-authored-by: Okon Samuel <[email protected]>
1 parent e8da6ba commit 5886633

File tree

5 files changed: +42 -20 lines changed


Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "MLJModelInterface"
 uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 authors = ["Thibaut Lienart and Anthony Blaom"]
-version = "1.5"
+version = "1.5.1"

 [deps]
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -10,7 +10,7 @@ StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"

 [compat]
 ScientificTypesBase = "3.0"
-StatisticalTraits = "3.0"
+StatisticalTraits = "3.1"
 julia = "1"

 [extras]

src/metadata_utils.jl

Lines changed: 8 additions & 1 deletion
@@ -85,6 +85,9 @@ Helper function to write the metadata for a model `T`.
 * `supports_class_weights=false`: whether the model supports class weights
 * `load_path="unknown"`: where the model is (usually `PackageName.ModelName`)
 * `human_name=nothing`: human name of the model
+* `supports_training_losses=nothing`: whether the (necessarily iterative) model can report
+  training losses
+* `reports_feature_importances=nothing`: whether the model reports feature importances

 ## Example

@@ -115,7 +118,9 @@ function metadata_model(
     supports_class_weights::Union{Nothing,Bool}=class_weights,
     docstring::Union{Nothing,String}=descr,
     load_path::Union{Nothing,String}=path,
-    human_name::Union{Nothing,String}=nothing
+    human_name::Union{Nothing,String}=nothing,
+    supports_training_losses::Union{Nothing,Bool}=nothing,
+    reports_feature_importances::Union{Nothing,Bool}=nothing,
 )
     docstring === nothing || Base.depwarn(DEPWARN_DOCSTRING, :metadata_model)

@@ -132,6 +137,8 @@ function metadata_model(
     _extend!(program, :docstring, docstring, T)
     _extend!(program, :load_path, load_path, T)
     _extend!(program, :human_name, human_name, T)
+    _extend!(program, :supports_training_losses, supports_training_losses, T)
+    _extend!(program, :reports_feature_importances, reports_feature_importances, T)

     parentmodule(T).eval(program)
 end
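
For orientation, a minimal usage sketch (not part of this commit) of how an implementing package might declare the two new traits through `metadata_model`; the model type `MyIterativeRegressor` and its `load_path` are hypothetical.

import MLJModelInterface
const MMI = MLJModelInterface

# hypothetical iterative model type
mutable struct MyIterativeRegressor <: MMI.Deterministic
    n_iterations::Int
end

# declare the new traits alongside the usual metadata
MMI.metadata_model(
    MyIterativeRegressor,
    load_path = "MyPackage.MyIterativeRegressor",   # hypothetical load path
    supports_training_losses = true,
    reports_feature_importances = true,
)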

src/model_api.jl

Lines changed: 9 additions & 8 deletions
@@ -41,7 +41,7 @@ in historical order. If the model calculates scores instead, then the
 sign of the scores should be reversed.

 The following trait overload is also required:
-`supports_training_losses(::Type{<:M}) = true`
+`MLJModelInterface.supports_training_losses(::Type{<:M}) = true`.

 """
 training_losses(model, report) = nothing
@@ -168,16 +168,17 @@ function evaluate end
 """
     feature_importances(model::M, fitresult, report)

-For a given `model` of model type `M` supporting intrinsic feature importances, calculate
-the feature importances from the model's `fitresult` and `report` as an
-abstract vector of `feature::Symbol => importance::Real` pairs
-(e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`).
+For a given `model` of model type `M` supporting intrinsic feature importances, calculate
+the feature importances from the model's `fitresult` and `report` as an
+abstract vector of `feature::Symbol => importance::Real` pairs
+(e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`).

 The following trait overload is also required:
-`reports_feature_importances(::Type{<:M}) = true`
+`MLJModelInterface.reports_feature_importances(::Type{<:M}) = true`

 If for some reason a model is sometimes unable to report feature importances then
-`feature_importances` should return all importances as 0.0, as in
-`[:gender =>0.0, :height =>0.0, :weight => 0.0]`.
+`feature_importances` should return all importances as 0.0, as in
+`[:gender =>0.0, :height =>0.0, :weight => 0.0]`.
+
 """
 function feature_importances end
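
To see both documented contracts side by side, here is a minimal sketch (not part of this commit), assuming a hypothetical model type `DummyIterativeModel` whose `fit` stores its losses and importances in the report.

import MLJModelInterface
const MMI = MLJModelInterface

# hypothetical iterative model run for `n` iterations
mutable struct DummyIterativeModel <: MMI.Deterministic
    n::Int
end

function MMI.fit(model::DummyIterativeModel, verbosity, X, y)
    fitresult = nothing                        # no real learned state in this sketch
    losses = [1.0 / i for i in 1:model.n]      # made-up decreasing loss curve
    importances = [:x1 => 0.6, :x2 => 0.4]     # made-up importances
    report = (losses = losses, importances = importances)
    return fitresult, nothing, report
end

# accessor methods documented above
MMI.training_losses(::DummyIterativeModel, report) = report.losses
MMI.feature_importances(::DummyIterativeModel, fitresult, report) = report.importances

# the required trait overloads
MMI.supports_training_losses(::Type{<:DummyIterativeModel}) = true
MMI.reports_feature_importances(::Type{<:DummyIterativeModel}) = true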

test/metadata_utils.jl

Lines changed: 2 additions & 0 deletions
@@ -142,6 +142,8 @@ infos = Dict(trait => eval(:(MLJModelInterface.$trait))(FooRegressor) for
     @test infos[:hyperparameters] == (:a, :b)
     @test infos[:hyperparameter_types] == ("Int64", "Any")
     @test infos[:hyperparameter_ranges] == (nothing, nothing)
+    @test !infos[:supports_training_losses]
+    @test !infos[:reports_feature_importances]
 end

 @testset "doc_header(ModelType)" begin

test/model_api.jl

Lines changed: 21 additions & 9 deletions
@@ -14,12 +14,22 @@ mutable struct APIx1 <: Static end
     @test selectrows(APIx0(), 2:3, X, y) == ((x1 = [4, 6],), [20.0, 30.0])
 end

+M.metadata_model(
+    APIx0,
+    supports_training_losses = true,
+    reports_feature_importances = true,
+)
+
+dummy_losses = [1.0, 2.0, 3.0]
+M.training_losses(::APIx0, report) = report
+M.feature_importances(::APIx0, fitresult, report) = [:a=>0, :b=>0]
+
 @testset "fit-x" begin
     m0 = APIx0(f0=1)
     m1 = APIx0b(f0=3)
     # no weight support: fallback
-    M.fit(m::APIx0, v::Int, X, y) = (5, nothing, nothing)
-    @test fit(m0, 1, randn(2), randn(2), 5) == (5, nothing, nothing)
+    M.fit(m::APIx0, v::Int, X, y) = (5, nothing, dummy_losses)
+    @test fit(m0, 1, randn(2), randn(2), 5) == (5, nothing, dummy_losses)
     # with weight support: use
     M.fit(m::APIx0b, v::Int, X, y, w) = (7, nothing, nothing)
     @test fit(m1, 1, randn(2), randn(2), 5) == (7, nothing, nothing)
@@ -32,16 +42,18 @@ end
     @test fit(s1, 1, 0) == (nothing, nothing, nothing)

     # update fallback = fit
-    @test update(m0, 1, 5, nothing, randn(2), 5) == (5, nothing, nothing)
+    @test update(m0, 1, 5, nothing, randn(2), 5) == (5, nothing, dummy_losses)

     # training losses:
     f, c, r = MLJModelInterface.fit(m0, 1, rand(2), rand(2))
-    @test M.training_losses(m0, r) === nothing
-
-    # intrinsic_importances
+    @test M.training_losses(m0, r) == dummy_losses
+
+    # training losses:
+    f, c, r = MLJModelInterface.fit(m0, 1, rand(2), rand(2))
+    @test M.training_losses(m0, r) == dummy_losses
+
+    # feature_importances
     f, c, r = MLJModelInterface.fit(m0, 1, rand(2), rand(2))
-    MLJModelInterface.reports_feature_importances(::Type{APIx0}) = true
-    MLJModelInterface.feature_importances(::APIx0, fitresult, report) = [:a=>0, :b=>0]
     @test MLJModelInterface.feature_importances(m0, f, r) == [:a=>0, :b=>0]
 end

@@ -67,7 +79,7 @@ mutable struct UnivariateFiniteFitter <: Probabilistic end
 end

 MMI.input_scitype(::Type{<:UnivariateFiniteFitter}) = Nothing
-
+
 MMI.target_scitype(::Type{<:UnivariateFiniteFitter}) = AbstractVector{<:Finite}

 y = categorical(collect("aabbccaa"))

0 commit comments
