|
| 1 | +# Remote methods are methods called on remote processes for the purpose of when extacting |
| 2 | +# model metadata for a package |
| 3 | + |
| 4 | + |
| 5 | +# # HELPERS |
| 6 | + |
| 7 | +function finaltypes(T::Type) |
| 8 | + s = InteractiveUtils.subtypes(T) |
| 9 | + if isempty(s) |
| 10 | + return [T, ] |
| 11 | + else |
| 12 | + return reduce(vcat, [finaltypes(S) for S in s]) |
| 13 | + end |
| 14 | +end |
| 15 | + |
| 16 | +""" |
| 17 | + model_type_given_constructor(modeltypes) |
| 18 | +
|
| 19 | +**Private method.** |
| 20 | +
|
| 21 | +Return a dictionary of `modeltypes`, keyed on constructor. Where multiple types share a |
| 22 | +single constructor, there can only be one value (and which value appears is not |
| 23 | +predictable). |
| 24 | +
|
| 25 | +Typically a model type and it's constructor have the same name, but for wrappers, such as |
| 26 | +`TunedModel`, several types share the same constructor (e.g., `DeterministicTunedModel`, |
| 27 | +`ProbabilisticTunedModel` are model types sharing constructor `TunedModel`). |
| 28 | +
|
| 29 | +""" |
| 30 | +function modeltype_given_constructor(modeltypes) |
| 31 | + |
| 32 | + # Note that wrappers are required to overload `MLJModelInterface.constructor` and the |
| 33 | + # fallback is `nothing`. |
| 34 | + |
| 35 | + return Dict( |
| 36 | + map(modeltypes) do M |
| 37 | + C = MLJModelInterface.constructor(M) |
| 38 | + Pair(isnothing(C) ? M : C, M) |
| 39 | + end..., |
| 40 | + ) |
| 41 | +end |
| 42 | + |
| 43 | +""" |
| 44 | + encode_dic(d) |
| 45 | +
|
| 46 | +Convert an arbitrary nested dictionary `d` into a nested dictionary whose leaf values are |
| 47 | +all strings, suitable for writing to a TOML file (a poor man's serialization). The rules |
| 48 | +for converting leaves are: |
| 49 | +
|
| 50 | +1. If it's a `Symbol`, preserve the colon, as in :x -> ":x" |
| 51 | +
|
| 52 | +2. If it's an `AbstractString`, apply `string` function (e.g, to remove `SubString`s) |
| 53 | +
|
| 54 | +3. In all other cases, except `AbstractArray`s, wrap in single quotes, as in sum -> "`sum`" |
| 55 | +
|
| 56 | +4. Replace any `#` character in the application of Rule 3 with `_` (to handle `gensym` names) |
| 57 | +
|
| 58 | +5. For an `AbstractVector`, broadcast the preceding Rules over its elements. |
| 59 | +
|
| 60 | +""" |
| 61 | +function encode_dic(s) |
| 62 | + prestring = string("`", s, "`") |
| 63 | + # hack for objects with gensyms in their string representation: |
| 64 | + str = replace(prestring, '#'=>'_') |
| 65 | + return str |
| 66 | +end |
| 67 | +encode_dic(s::Symbol) = string(":", s) |
| 68 | +encode_dic(s::AbstractString) = string(s) |
| 69 | +encode_dic(v::AbstractVector) = encode_dic.(v) |
| 70 | +function encode_dic(d::AbstractDict) |
| 71 | + ret = LittleDict{}() |
| 72 | + for (k, v) in d |
| 73 | + ret[encode_dic(k)] = encode_dic(v) |
| 74 | + end |
| 75 | + return ret |
| 76 | +end |
| 77 | + |
| 78 | + |
| 79 | +# # REMOTE METHODS |
| 80 | + |
| 81 | +function traits_given_constructor_name() |
| 82 | + |
| 83 | + # Some explanation for the gymnamstics going on here: The model registry is actually |
| 84 | + # keyed on constructor names, not model type names, a change from the way the registry |
| 85 | + # was initially set up. These are usually the same, but wrappers frequently provide |
| 86 | + # exceptions; e.g., "TunedModel" is a constructor for two model types |
| 87 | + # "ProbabilisticTunedModel" and "DeterministicTunedModel". Unfortunately, what is easy |
| 88 | + # to grab are the model type names (we look for subtypes of `Model`) and we get the |
| 89 | + # constructors after, through the `constructor` trait. Only one |
| 90 | + |
| 91 | + modeltypes = filter(finaltypes(MLJModelInterface.Model)) do T |
| 92 | + !(isabstracttype(T)) |
| 93 | + end |
| 94 | + modeltype_given_constructor = MLJModelRegistry.modeltype_given_constructor(modeltypes) |
| 95 | + constructors = keys(modeltype_given_constructor) |> collect |
| 96 | + sort!(constructors, by=string) |
| 97 | + traits_given_constructor_name = Dict{String,Any}() |
| 98 | + |
| 99 | + for C in constructors |
| 100 | + M = modeltype_given_constructor[C] |
| 101 | + check_traits(M) |
| 102 | + constructor_name = split(string(C), '.') |> last |
| 103 | + traits = LittleDict{Symbol,Any}(trait => eval(:(MLJModelInterface.$trait))(M) |
| 104 | + for trait in MLJModelInterface.MODEL_TRAITS) |
| 105 | + traits[:name] = constructor_name |
| 106 | + traits_given_constructor_name[constructor_name] = traits |
| 107 | + end |
| 108 | + |
| 109 | + return encode_dic(traits_given_constructor_name) |
| 110 | +end |
0 commit comments