|
| 1 | + |
| 2 | +@with_kw_noshow mutable struct ContinuousEncoder <: Unsupervised |
| 3 | + drop_last::Bool = false |
| 4 | + one_hot_ordered_factors::Bool = false |
| 5 | +end |
| 6 | + |
| 7 | +function MMI.fit(transformer::ContinuousEncoder, verbosity::Int, X) |
| 8 | + |
| 9 | + # what features can be converted and therefore kept? |
| 10 | + s = schema(X) |
| 11 | + features = s.names |
| 12 | + scitypes = s.scitypes |
| 13 | + Convertible = Union{Continuous, Finite, Count} |
| 14 | + feature_scitype_tuples = zip(features, scitypes) |> collect |
| 15 | + features_to_keep = |
| 16 | + first.(filter(t -> last(t) <: Convertible, feature_scitype_tuples)) |
| 17 | + features_to_be_dropped = setdiff(collect(features), features_to_keep) |
| 18 | + |
| 19 | + if verbosity > 0 |
| 20 | + if !isempty(features_to_be_dropped) |
| 21 | + @info "Some features cannot be replaced with "* |
| 22 | + "`Continuous` features and will be dropped: "* |
| 23 | + "$features_to_be_dropped. " |
| 24 | + end |
| 25 | + end |
| 26 | + |
| 27 | + # fit the one-hot encoder: |
| 28 | + hot_encoder = |
| 29 | + OneHotEncoder(ordered_factor=transformer.one_hot_ordered_factors, |
| 30 | + drop_last=transformer.drop_last) |
| 31 | + hot_fitresult, _, hot_report = MMI.fit(hot_encoder, verbosity - 1, X) |
| 32 | + |
| 33 | + new_features = setdiff(hot_report.new_features, features_to_be_dropped) |
| 34 | + |
| 35 | + fitresult = (features_to_keep=features_to_keep, |
| 36 | + one_hot_encoder=hot_encoder, |
| 37 | + one_hot_encoder_fitresult=hot_fitresult) |
| 38 | + |
| 39 | + # generate the report: |
| 40 | + report = (features_to_keep=features_to_keep, |
| 41 | + new_features=new_features) |
| 42 | + |
| 43 | + cache = nothing |
| 44 | + |
| 45 | + return fitresult, cache, report |
| 46 | + |
| 47 | +end |
| 48 | + |
| 49 | +MMI.fitted_params(::ContinuousEncoder, fitresult) = fitresult |
| 50 | + |
| 51 | +function MMI.transform(transformer::ContinuousEncoder, fitresult, X) |
| 52 | + |
| 53 | + features_to_keep, hot_encoder, hot_fitresult = values(fitresult) |
| 54 | + |
| 55 | + # dump unseen or untransformable features: |
| 56 | + if !issubset(features_to_keep, MMI.schema(X).names) |
| 57 | + throw( |
| 58 | + ArgumentError( |
| 59 | + "Supplied frame does not admit previously selected features." |
| 60 | + ) |
| 61 | + ) |
| 62 | + end |
| 63 | + X0 = MMI.selectcols(X, features_to_keep) |
| 64 | + |
| 65 | + # one-hot encode: |
| 66 | + X1 = transform(hot_encoder, hot_fitresult, X0) |
| 67 | + |
| 68 | + # convert remaining to continuous: |
| 69 | + return coerce(X1, Count=>Continuous, OrderedFactor=>Continuous) |
| 70 | + |
| 71 | +end |
| 72 | + |
| 73 | +metadata_model(ContinuousEncoder, |
| 74 | + input_scitype = Table, |
| 75 | + output_scitype = Table(Continuous), |
| 76 | + load_path = "MLJModels.ContinuousEncoder") |
| 77 | + |
| 78 | +""" |
| 79 | +$(MLJModelInterface.doc_header(ContinuousEncoder)) |
| 80 | +
|
| 81 | +Use this model to arrange all features (features) of a table to have |
| 82 | +`Continuous` element scitype, by applying the following protocol to |
| 83 | +each feature `ftr`: |
| 84 | +
|
| 85 | +- If `ftr` is already `Continuous` retain it. |
| 86 | +
|
| 87 | +- If `ftr` is `Multiclass`, one-hot encode it. |
| 88 | +
|
| 89 | +- If `ftr` is `OrderedFactor`, replace it with `coerce(ftr, |
| 90 | + Continuous)` (vector of floating point integers), unless |
| 91 | + `ordered_factors=false` is specified, in which case one-hot encode |
| 92 | + it. |
| 93 | +
|
| 94 | +- If `ftr` is `Count`, replace it with `coerce(ftr, Continuous)`. |
| 95 | +
|
| 96 | +- If `ftr` has some other element scitype, or was not observed in |
| 97 | + fitting the encoder, drop it from the table. |
| 98 | +
|
| 99 | +**Warning:** This transformer assumes that `levels(col)` for any |
| 100 | +`Multiclass` or `OrderedFactor` column, `col`, is the same for |
| 101 | +training data and new data to be transformed. |
| 102 | +
|
| 103 | +To selectively one-hot-encode categorical features (without dropping |
| 104 | +features) use [`OneHotEncoder`](@ref) instead. |
| 105 | +
|
| 106 | +
|
| 107 | +# Training data |
| 108 | +
|
| 109 | +In MLJ or MLJBase, bind an instance `model` to data with |
| 110 | +
|
| 111 | + mach = machine(model, X) |
| 112 | +
|
| 113 | +where |
| 114 | +
|
| 115 | +- `X`: any Tables.jl compatible table. features can be of mixed type |
| 116 | + but only those with element scitype `Multiclass` or `OrderedFactor` |
| 117 | + can be encoded. Check column scitypes with `schema(X)`. |
| 118 | +
|
| 119 | +Train the machine using `fit!(mach, rows=...)`. |
| 120 | +
|
| 121 | +
|
| 122 | +# Hyper-parameters |
| 123 | +
|
| 124 | +- `drop_last=true`: whether to drop the column corresponding to the |
| 125 | + final class of one-hot encoded features. For example, a three-class |
| 126 | + feature is spawned into three new features if `drop_last=false`, but |
| 127 | + two just features otherwise. |
| 128 | +
|
| 129 | +- `one_hot_ordered_factors=false`: whether to one-hot any feature |
| 130 | + with `OrderedFactor` element scitype, or to instead coerce it |
| 131 | + directly to a (single) `Continuous` feature using the order |
| 132 | +
|
| 133 | +
|
| 134 | +# Fitted parameters |
| 135 | +
|
| 136 | +The fields of `fitted_params(mach)` are: |
| 137 | +
|
| 138 | +- `features_to_keep`: names of features that will not be dropped from |
| 139 | + the table |
| 140 | +
|
| 141 | +- `one_hot_encoder`: the `OneHotEncoder` model instance for handling |
| 142 | + the one-hot encoding |
| 143 | +
|
| 144 | +- `one_hot_encoder_fitresult`: the fitted parameters of the |
| 145 | + `OneHotEncoder` model |
| 146 | +
|
| 147 | +
|
| 148 | +# Report |
| 149 | +
|
| 150 | +- `features_to_keep`: names of input features that will not be dropped |
| 151 | + from the table |
| 152 | +
|
| 153 | +- `new_features`: names of all output features |
| 154 | +
|
| 155 | +
|
| 156 | +# Example |
| 157 | +
|
| 158 | +```julia |
| 159 | +X = (name=categorical(["Danesh", "Lee", "Mary", "John"]), |
| 160 | + grade=categorical(["A", "B", "A", "C"], ordered=true), |
| 161 | + height=[1.85, 1.67, 1.5, 1.67], |
| 162 | + n_devices=[3, 2, 4, 3], |
| 163 | + comments=["the force", "be", "with you", "too"]) |
| 164 | +
|
| 165 | +julia> schema(X) |
| 166 | +βββββββββββββ¬βββββββββββββββββββ |
| 167 | +β names β scitypes β |
| 168 | +βββββββββββββΌβββββββββββββββββββ€ |
| 169 | +β name β Multiclass{4} β |
| 170 | +β grade β OrderedFactor{3} β |
| 171 | +β height β Continuous β |
| 172 | +β n_devices β Count β |
| 173 | +β comments β Textual β |
| 174 | +βββββββββββββ΄βββββββββββββββββββ |
| 175 | +
|
| 176 | +encoder = ContinuousEncoder(drop_last=true) |
| 177 | +mach = fit!(machine(encoder, X)) |
| 178 | +W = transform(mach, X) |
| 179 | +
|
| 180 | +julia> schema(W) |
| 181 | +ββββββββββββββββ¬βββββββββββββ |
| 182 | +β names β scitypes β |
| 183 | +ββββββββββββββββΌβββββββββββββ€ |
| 184 | +β name__Danesh β Continuous β |
| 185 | +β name__John β Continuous β |
| 186 | +β name__Lee β Continuous β |
| 187 | +β grade β Continuous β |
| 188 | +β height β Continuous β |
| 189 | +β n_devices β Continuous β |
| 190 | +ββββββββββββββββ΄βββββββββββββ |
| 191 | +
|
| 192 | +julia> setdiff(schema(X).names, report(mach).features_to_keep) # dropped features |
| 193 | +1-element Vector{Symbol}: |
| 194 | + :comments |
| 195 | +
|
| 196 | +``` |
| 197 | +
|
| 198 | +See also [`OneHotEncoder`](@ref) |
| 199 | +""" |
| 200 | +ContinuousEncoder |
0 commit comments