Skip to content

Commit e6dd9eb

Browse files
committed
add metadata(pkg); WIP
1 parent 598f249 commit e6dd9eb

File tree

4 files changed

+194
-0
lines changed

4 files changed

+194
-0
lines changed

src/MLJModelRegistry.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ const REGISTRY = joinpath(ROOT, "registry")
1313
include("GenericRegistry.jl")
1414
include("check_traits.jl")
1515
include("remote_methods.jl")
16+
include("methods.jl")
1617

1718
end # module

src/check_traits.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# TODO: This method should live at MLJModelInterface !
2+
3+
ismissing_or_isa(x, T) = ismissing(x) || x isa T
4+
5+
function check_traits(M)
6+
message = "$M has a bad trait declaration.\n"
7+
ismissing_or_isa(MLJModelInterface.is_pure_julia(M), Bool) ||
8+
error(message*"`is_pure_julia` must return true or false")
9+
ismissing_or_isa(MLJModelInterface.supports_weights(M), Bool) ||
10+
error(message*"`supports_weights` must return `true`, "*
11+
"`false` or `missing`. ")
12+
ismissing_or_isa(MLJModelInterface.supports_class_weights(M), Bool) ||
13+
error(message*"`supports_class_weights` must return `true`, "*
14+
"`false` or `missing`. ")
15+
MLJModelInterface.is_wrapper(M) isa Bool ||
16+
error(message*"`is_wrapper` must return `true` or `false`. ")
17+
load_path = MLJModelInterface.load_path(M)
18+
load_path isa String ||
19+
error(message*"`load_path` must return a `String`. ")
20+
contains(load_path, "unknown") &&
21+
error(message*"`load_path` return value contains string \"unknown\". ")
22+
pkg = MLJModelInterface.package_name(M)
23+
pkg isa String || error(message*"`package_name` must return a `String`. ")
24+
api_pkg = split(load_path, '.') |> first
25+
pkg == "unknown" && error(message*"`package_name` returns \"unknown\". ")
26+
load_path_ex = Meta.parse(load_path)
27+
api_pkg_ex = Symbol(api_pkg)
28+
import_ex = :(import $api_pkg_ex)
29+
program_to_test_load_path = quote
30+
$import_ex
31+
$load_path_ex
32+
end
33+
try
34+
Main.eval(program_to_test_load_path)
35+
catch excptn
36+
error(message*"Cannot import value of `load_path` (parsed as expression). ")
37+
rethrow(excptn)
38+
end
39+
end

src/methods.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
const DOC_ADDING_PACKAGES = """
2+
3+
Use Julia's package manager, `Pkg`, to add a new model-providing package to the
4+
environment at "/registry/". This is necessary before model medata can be generated
5+
and recorded.
6+
7+
IMPORTANT: In any pull request to update the Model Registry you should note the final
8+
output of `Pkg.status(outdated=true)`.
9+
10+
In detail, you will generally perform the following steps to add the new package:
11+
12+
1. In your local clone of MLJModelRegistry.jl, `activate` the environment at "/registry/".
13+
14+
2. `update` the environment
15+
16+
3. Note the output of `Pkg.status(outdated=true)`
17+
18+
3. `add` the new package
19+
20+
4. Repeat steps 2 and 3 above, and investigate any dependeny downgrades for which your addition may be responsible.
21+
22+
If adding the new package results in downgrades to existing dependencies, your pull
23+
request to register the new models may be rejected.
24+
25+
"""
26+
27+
err_missing_package(pkg) = ArgumentError("""
28+
The package \"$pkg\" could not be found in the model registry. $DOC_ADDING_PACKAGES
29+
"""
30+
)
31+
32+
function metadata(pkg)
33+
pkg in GenericRegistry.dependencies(REGISTRY) || throw(err_missing_package(pkg))
34+
setup = quote
35+
# REMOVE THIS NEXT LINE AFTER TAGGING NEW MLJMODELINTERFACE
36+
Pkg.develop(path="/Users/anthony/MLJ/MLJModelInterface/")
37+
Pkg.develop(path=$ROOT)
38+
end
39+
program = quote
40+
import MLJModelRegistry
41+
MLJModelRegistry.traits_given_constructor_name()
42+
end
43+
return GenericRegistry.run(setup, pkg, program)
44+
end

src/remote_methods.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Remote methods are methods called on remote processes for the purpose of when extacting
2+
# model metadata for a package
3+
4+
5+
# # HELPERS
6+
7+
function finaltypes(T::Type)
8+
s = InteractiveUtils.subtypes(T)
9+
if isempty(s)
10+
return [T, ]
11+
else
12+
return reduce(vcat, [finaltypes(S) for S in s])
13+
end
14+
end
15+
16+
"""
17+
model_type_given_constructor(modeltypes)
18+
19+
**Private method.**
20+
21+
Return a dictionary of `modeltypes`, keyed on constructor. Where multiple types share a
22+
single constructor, there can only be one value (and which value appears is not
23+
predictable).
24+
25+
Typically a model type and it's constructor have the same name, but for wrappers, such as
26+
`TunedModel`, several types share the same constructor (e.g., `DeterministicTunedModel`,
27+
`ProbabilisticTunedModel` are model types sharing constructor `TunedModel`).
28+
29+
"""
30+
function modeltype_given_constructor(modeltypes)
31+
32+
# Note that wrappers are required to overload `MLJModelInterface.constructor` and the
33+
# fallback is `nothing`.
34+
35+
return Dict(
36+
map(modeltypes) do M
37+
C = MLJModelInterface.constructor(M)
38+
Pair(isnothing(C) ? M : C, M)
39+
end...,
40+
)
41+
end
42+
43+
"""
44+
encode_dic(d)
45+
46+
Convert an arbitrary nested dictionary `d` into a nested dictionary whose leaf values are
47+
all strings, suitable for writing to a TOML file (a poor man's serialization). The rules
48+
for converting leaves are:
49+
50+
1. If it's a `Symbol`, preserve the colon, as in :x -> ":x"
51+
52+
2. If it's an `AbstractString`, apply `string` function (e.g, to remove `SubString`s)
53+
54+
3. In all other cases, except `AbstractArray`s, wrap in single quotes, as in sum -> "`sum`"
55+
56+
4. Replace any `#` character in the application of Rule 3 with `_` (to handle `gensym` names)
57+
58+
5. For an `AbstractVector`, broadcast the preceding Rules over its elements.
59+
60+
"""
61+
function encode_dic(s)
62+
prestring = string("`", s, "`")
63+
# hack for objects with gensyms in their string representation:
64+
str = replace(prestring, '#'=>'_')
65+
return str
66+
end
67+
encode_dic(s::Symbol) = string(":", s)
68+
encode_dic(s::AbstractString) = string(s)
69+
encode_dic(v::AbstractVector) = encode_dic.(v)
70+
function encode_dic(d::AbstractDict)
71+
ret = LittleDict{}()
72+
for (k, v) in d
73+
ret[encode_dic(k)] = encode_dic(v)
74+
end
75+
return ret
76+
end
77+
78+
79+
# # REMOTE METHODS
80+
81+
function traits_given_constructor_name()
82+
83+
# Some explanation for the gymnamstics going on here: The model registry is actually
84+
# keyed on constructor names, not model type names, a change from the way the registry
85+
# was initially set up. These are usually the same, but wrappers frequently provide
86+
# exceptions; e.g., "TunedModel" is a constructor for two model types
87+
# "ProbabilisticTunedModel" and "DeterministicTunedModel". Unfortunately, what is easy
88+
# to grab are the model type names (we look for subtypes of `Model`) and we get the
89+
# constructors after, through the `constructor` trait. Only one
90+
91+
modeltypes = filter(finaltypes(MLJModelInterface.Model)) do T
92+
!(isabstracttype(T))
93+
end
94+
modeltype_given_constructor = MLJModelRegistry.modeltype_given_constructor(modeltypes)
95+
constructors = keys(modeltype_given_constructor) |> collect
96+
sort!(constructors, by=string)
97+
traits_given_constructor_name = Dict{String,Any}()
98+
99+
for C in constructors
100+
M = modeltype_given_constructor[C]
101+
check_traits(M)
102+
constructor_name = split(string(C), '.') |> last
103+
traits = LittleDict{Symbol,Any}(trait => eval(:(MLJModelInterface.$trait))(M)
104+
for trait in MLJModelInterface.MODEL_TRAITS)
105+
traits[:name] = constructor_name
106+
traits_given_constructor_name[constructor_name] = traits
107+
end
108+
109+
return encode_dic(traits_given_constructor_name)
110+
end

0 commit comments

Comments
 (0)