Skip to content

Commit c1cd479

Browse files
authored
Merge pull request #83 from alan-turing-institute/dev
Add correct trait names to utilities for declaring trait values
2 parents 485f839 + 62ba51f commit c1cd479

File tree

3 files changed

+73
-48
lines changed

3 files changed

+73
-48
lines changed

src/metadata_utils.jl

Lines changed: 71 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -25,37 +25,49 @@ a series of models.
2525
2626
## Keywords
2727
28-
* `name="unknown"` : package name
29-
* `uuid="unknown"` : package uuid
30-
* `url="unknown"` : package url
31-
* `julia=missing` : whether the package is pure julia
32-
* `license="unknown"`: package license
28+
* `package_name="unknown"` : package name
29+
* `package_uuid="unknown"` : package uuid
30+
* `package_url="unknown"` : package url
31+
* `is_pure_julia=missing` : whether the package is pure julia
32+
* `package_license="unknown"`: package license
3333
* `is_wrapper=false` : whether the package is a wrapper
3434
3535
## Example
3636
3737
```
3838
metadata_pkg.((KNNRegressor, KNNClassifier),
39-
name="NearestNeighbors",
40-
uuid="b8a86587-4115-5ab1-83bc-aa920d37bbce",
41-
url="https://github.com/KristofferC/NearestNeighbors.jl",
42-
julia=true,
43-
license="MIT",
39+
package_name="NearestNeighbors",
40+
package_uuid="b8a86587-4115-5ab1-83bc-aa920d37bbce",
41+
package_url="https://github.com/KristofferC/NearestNeighbors.jl",
42+
is_pure_julia=true,
43+
package_license="MIT",
4444
is_wrapper=false)
4545
```
4646
"""
47-
function metadata_pkg(T; name::String="unknown",
48-
uuid::String="unknown",
49-
url::String="unknown",
50-
julia::Union{Missing,Bool}=missing,
51-
license::String="unknown",
52-
is_wrapper::Bool=false)
47+
function metadata_pkg(T;
48+
49+
# aliases:
50+
name::String="unknown",
51+
uuid::String="unknown",
52+
url::String="unknown",
53+
julia::Union{Missing,Bool}=missing,
54+
license::String="unknown",
55+
is_wrapper::Bool=false,
56+
57+
# preferred names, corresponding to trait names:
58+
package_name=name,
59+
package_uuid=uuid,
60+
package_url=url,
61+
is_pure_julia=julia,
62+
package_license=license,
63+
64+
)
5365
ex = quote
54-
MLJModelInterface.package_name(::Type{<:$T}) = $name
55-
MLJModelInterface.package_uuid(::Type{<:$T}) = $uuid
56-
MLJModelInterface.package_url(::Type{<:$T}) = $url
57-
MLJModelInterface.is_pure_julia(::Type{<:$T}) = $julia
58-
MLJModelInterface.package_license(::Type{<:$T}) = $license
66+
MLJModelInterface.package_name(::Type{<:$T}) = $package_name
67+
MLJModelInterface.package_uuid(::Type{<:$T}) = $package_uuid
68+
MLJModelInterface.package_url(::Type{<:$T}) = $package_url
69+
MLJModelInterface.is_pure_julia(::Type{<:$T}) = $is_pure_julia
70+
MLJModelInterface.package_license(::Type{<:$T}) = $package_license
5971
MLJModelInterface.is_wrapper(::Type{<:$T}) = $is_wrapper
6072
end
6173
parentmodule(T).eval(ex)
@@ -68,44 +80,57 @@ Helper function to write the metadata for a model `T`.
6880
6981
## Keywords
7082
71-
* `input=Unknown` : allowed scientific type of the input data
72-
* `target=Unknown`: allowed sc. type of the target (supervised)
73-
* `output=Unknown`: allowed sc. type of the transformed data (unsupervised)
74-
* `weights=false` : whether the model supports sample weights
75-
* `descr=""` : short description of the model
76-
* `path=""` : where the model is (usually `PackageName.ModelName`)
83+
* `input_scitype=Unknown` : allowed scientific type of the input data
84+
* `target_scitype=Unknown`: allowed sc. type of the target (supervised)
85+
* `output_scitype=Unknown`: allowed sc. type of the transformed data (unsupervised)
86+
* `supports_weights=false` : whether the model supports sample weights
87+
* `docstring=""` : short description of the model
88+
* `load_path=""` : where the model is (usually `PackageName.ModelName`)
7789
7890
## Example
7991
8092
```
8193
metadata_model(KNNRegressor,
82-
input=MLJModelInterface.Table(MLJModelInterface.Continuous),
83-
target=AbstractVector{MLJModelInterface.Continuous},
84-
weights=true,
85-
descr="K-Nearest Neighbors classifier: ...",
86-
path="NearestNeighbors.KNNRegressor")
94+
input_scitype=MLJModelInterface.Table(MLJModelInterface.Continuous),
95+
target_scitype=AbstractVector{MLJModelInterface.Continuous},
96+
supports_weights=true,
97+
docstring="K-Nearest Neighbors classifier: ...",
98+
load_path="NearestNeighbors.KNNRegressor")
8799
```
88100
"""
89-
function metadata_model(T; input=Unknown,
90-
target=Unknown,
91-
output=Unknown,
92-
weights::Bool=false,
93-
descr::String="",
94-
path::String="")
95-
if isempty(path)
101+
function metadata_model(T;
102+
103+
# aliases:
104+
input=Unknown,
105+
target=Unknown,
106+
output=Unknown,
107+
weights::Bool=false,
108+
descr::String="",
109+
path::String="",
110+
111+
# preferred names, corresponding to trait names:
112+
input_scitype=input,
113+
target_scitype=target,
114+
output_scitype=output,
115+
supports_weights=weights,
116+
docstring=descr,
117+
load_path=path,
118+
119+
)
120+
if isempty(load_path)
96121
pname = MLJModelInterface.package_name(T)
97122
mname = MLJModelInterface.name(T)
98-
path = "MLJModels.$(pname)_.$(mname)"
123+
load_path = "MLJModels.$(pname)_.$(mname)"
99124
end
100125
ex = quote
101-
MLJModelInterface.input_scitype(::Type{<:$T}) = $input
102-
MLJModelInterface.output_scitype(::Type{<:$T}) = $output
103-
MLJModelInterface.target_scitype(::Type{<:$T}) = $target
104-
MLJModelInterface.supports_weights(::Type{<:$T}) = $weights
105-
MLJModelInterface.load_path(::Type{<:$T}) = $path
126+
MLJModelInterface.input_scitype(::Type{<:$T}) = $input_scitype
127+
MLJModelInterface.output_scitype(::Type{<:$T}) = $output_scitype
128+
MLJModelInterface.target_scitype(::Type{<:$T}) = $target_scitype
129+
MLJModelInterface.supports_weights(::Type{<:$T}) = $supports_weights
130+
MLJModelInterface.load_path(::Type{<:$T}) = $load_path
106131

107132
MLJModelInterface.docstring(::Type{<:$T}) =
108-
MLJModelInterface.docstring_ext($T; descr=$descr)
133+
MLJModelInterface.docstring_ext($T; descr=$docstring)
109134
end
110135
parentmodule(T).eval(ex)
111136
end

test/data_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ end
6363
setlight()
6464
ary = rand(10, 3)
6565
@test_throws M.InterfaceError M.schema(ary)
66-
df = DataFrame(rand(10, 3))
66+
df = DataFrame(rand(10, 3), :auto)
6767
@test_throws M.InterfaceError M.schema(df)
6868
end
6969
@testset "schema-full" begin

test/model_traits.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ struct Apple end
7373

7474
end
7575

76-
import .Banana
76+
import .Fruit
7777

7878
@testset "extras" begin
7979
@test docstring(Float64) == "Float64"

0 commit comments

Comments
 (0)