Skip to content

Commit 1766340

Browse files
committed
add supports_class_weights as kwarg for metadata_model(; kwargs...)
1 parent f66fb59 commit 1766340

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

src/metadata_utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ Helper function to write the metadata for a model `T`.
7878
* `target_scitype=Unknown`: allowed scitype of the target (supervised)
7979
* `output_scitype=Unkonwn`: allowed scitype of the transformed data (unsupervised)
8080
* `supports_weights=false`: whether the model supports sample weights
81+
* `supports_class_weights=false`: whether the model supports class weights
8182
* `load_path="unknown"`: where the model is (usually `PackageName.ModelName`)
8283
8384
## Example
@@ -97,6 +98,7 @@ function metadata_model(
9798
target=nothing,
9899
output=nothing,
99100
weights::Union{Nothing,Bool}=nothing,
101+
class_weights::Union{Nothing,Bool}=nothing,
100102
descr::Union{Nothing,String}=nothing,
101103
path::Union{Nothing,String}=nothing,
102104

@@ -105,9 +107,11 @@ function metadata_model(
105107
target_scitype=target,
106108
output_scitype=output,
107109
supports_weights::Union{Nothing,Bool}=weights,
110+
supports_class_weights::Union{Nothing,Bool}=weights,
108111
docstring::Union{Nothing,String}=descr,
109112
load_path::Union{Nothing,String}=path,
110113
)
114+
111115
load_path === nothing && @warn WARN_MISSING_LOAD_PATH
112116

113117
program = quote end
@@ -119,6 +123,7 @@ function metadata_model(
119123
_extend!(program, :target_scitype, target_scitype, T)
120124
_extend!(program, :output_scitype, output_scitype, T)
121125
_extend!(program, :supports_weights, supports_weights, T)
126+
_extend!(program, :supports_class_weights,supports_class_weights, T)
122127
_extend!(program, :docstring, docstring, T)
123128
_extend!(program, :load_path, load_path, T)
124129

test/metadata_utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ M.implemented_methods(::FI, M::Type{<:MLJType}) =
2929
metadata_model(FooRegressor,
3030
input_scitype=Table(Continuous),
3131
target_scitype=AbstractVector{Continuous},
32+
supports_class_weights=true,
3233
load_path="goo goo")
3334

3435
infos = Dict(trait => eval(:(MLJModelInterface.$trait))(FooRegressor) for
@@ -46,6 +47,7 @@ infos = Dict(trait => eval(:(MLJModelInterface.$trait))(FooRegressor) for
4647
@test infos[:package_url] == "http://existentialcomics.com/"
4748
@test !infos[:is_wrapper]
4849
@test !infos[:supports_weights]
50+
@test infos[:supports_class_weights]
4951
@test !infos[:supports_online]
5052
@test infos[:docstring] == "Cool model\n"
5153
@test infos[:name] == "FooRegressor"

0 commit comments

Comments
 (0)