Skip to content

Commit d05c10a

Browse files
committed
✍🏻 Complete rename
1 parent d2de32a commit d05c10a

17 files changed

+2621
-0
lines changed
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
2+
@with_kw_noshow mutable struct ContinuousEncoder <: Unsupervised
3+
drop_last::Bool = false
4+
one_hot_ordered_factors::Bool = false
5+
end
6+
7+
function MMI.fit(transformer::ContinuousEncoder, verbosity::Int, X)
8+
9+
# what features can be converted and therefore kept?
10+
s = schema(X)
11+
features = s.names
12+
scitypes = s.scitypes
13+
Convertible = Union{Continuous, Finite, Count}
14+
feature_scitype_tuples = zip(features, scitypes) |> collect
15+
features_to_keep =
16+
first.(filter(t -> last(t) <: Convertible, feature_scitype_tuples))
17+
features_to_be_dropped = setdiff(collect(features), features_to_keep)
18+
19+
if verbosity > 0
20+
if !isempty(features_to_be_dropped)
21+
@info "Some features cannot be replaced with "*
22+
"`Continuous` features and will be dropped: "*
23+
"$features_to_be_dropped. "
24+
end
25+
end
26+
27+
# fit the one-hot encoder:
28+
hot_encoder =
29+
OneHotEncoder(ordered_factor=transformer.one_hot_ordered_factors,
30+
drop_last=transformer.drop_last)
31+
hot_fitresult, _, hot_report = MMI.fit(hot_encoder, verbosity - 1, X)
32+
33+
new_features = setdiff(hot_report.new_features, features_to_be_dropped)
34+
35+
fitresult = (features_to_keep=features_to_keep,
36+
one_hot_encoder=hot_encoder,
37+
one_hot_encoder_fitresult=hot_fitresult)
38+
39+
# generate the report:
40+
report = (features_to_keep=features_to_keep,
41+
new_features=new_features)
42+
43+
cache = nothing
44+
45+
return fitresult, cache, report
46+
47+
end
48+
49+
MMI.fitted_params(::ContinuousEncoder, fitresult) = fitresult
50+
51+
function MMI.transform(transformer::ContinuousEncoder, fitresult, X)
52+
53+
features_to_keep, hot_encoder, hot_fitresult = values(fitresult)
54+
55+
# dump unseen or untransformable features:
56+
if !issubset(features_to_keep, MMI.schema(X).names)
57+
throw(
58+
ArgumentError(
59+
"Supplied frame does not admit previously selected features."
60+
)
61+
)
62+
end
63+
X0 = MMI.selectcols(X, features_to_keep)
64+
65+
# one-hot encode:
66+
X1 = transform(hot_encoder, hot_fitresult, X0)
67+
68+
# convert remaining to continuous:
69+
return coerce(X1, Count=>Continuous, OrderedFactor=>Continuous)
70+
71+
end
72+
73+
metadata_model(ContinuousEncoder,
74+
input_scitype = Table,
75+
output_scitype = Table(Continuous),
76+
load_path = "MLJModels.ContinuousEncoder")
77+
78+
"""
79+
$(MLJModelInterface.doc_header(ContinuousEncoder))
80+
81+
Use this model to arrange all features (features) of a table to have
82+
`Continuous` element scitype, by applying the following protocol to
83+
each feature `ftr`:
84+
85+
- If `ftr` is already `Continuous` retain it.
86+
87+
- If `ftr` is `Multiclass`, one-hot encode it.
88+
89+
- If `ftr` is `OrderedFactor`, replace it with `coerce(ftr,
90+
Continuous)` (vector of floating point integers), unless
91+
`ordered_factors=false` is specified, in which case one-hot encode
92+
it.
93+
94+
- If `ftr` is `Count`, replace it with `coerce(ftr, Continuous)`.
95+
96+
- If `ftr` has some other element scitype, or was not observed in
97+
fitting the encoder, drop it from the table.
98+
99+
**Warning:** This transformer assumes that `levels(col)` for any
100+
`Multiclass` or `OrderedFactor` column, `col`, is the same for
101+
training data and new data to be transformed.
102+
103+
To selectively one-hot-encode categorical features (without dropping
104+
features) use [`OneHotEncoder`](@ref) instead.
105+
106+
107+
# Training data
108+
109+
In MLJ or MLJBase, bind an instance `model` to data with
110+
111+
mach = machine(model, X)
112+
113+
where
114+
115+
- `X`: any Tables.jl compatible table. features can be of mixed type
116+
but only those with element scitype `Multiclass` or `OrderedFactor`
117+
can be encoded. Check column scitypes with `schema(X)`.
118+
119+
Train the machine using `fit!(mach, rows=...)`.
120+
121+
122+
# Hyper-parameters
123+
124+
- `drop_last=true`: whether to drop the column corresponding to the
125+
final class of one-hot encoded features. For example, a three-class
126+
feature is spawned into three new features if `drop_last=false`, but
127+
two just features otherwise.
128+
129+
- `one_hot_ordered_factors=false`: whether to one-hot any feature
130+
with `OrderedFactor` element scitype, or to instead coerce it
131+
directly to a (single) `Continuous` feature using the order
132+
133+
134+
# Fitted parameters
135+
136+
The fields of `fitted_params(mach)` are:
137+
138+
- `features_to_keep`: names of features that will not be dropped from
139+
the table
140+
141+
- `one_hot_encoder`: the `OneHotEncoder` model instance for handling
142+
the one-hot encoding
143+
144+
- `one_hot_encoder_fitresult`: the fitted parameters of the
145+
`OneHotEncoder` model
146+
147+
148+
# Report
149+
150+
- `features_to_keep`: names of input features that will not be dropped
151+
from the table
152+
153+
- `new_features`: names of all output features
154+
155+
156+
# Example
157+
158+
```julia
159+
X = (name=categorical(["Danesh", "Lee", "Mary", "John"]),
160+
grade=categorical(["A", "B", "A", "C"], ordered=true),
161+
height=[1.85, 1.67, 1.5, 1.67],
162+
n_devices=[3, 2, 4, 3],
163+
comments=["the force", "be", "with you", "too"])
164+
165+
julia> schema(X)
166+
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
167+
β”‚ names β”‚ scitypes β”‚
168+
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
169+
β”‚ name β”‚ Multiclass{4} β”‚
170+
β”‚ grade β”‚ OrderedFactor{3} β”‚
171+
β”‚ height β”‚ Continuous β”‚
172+
β”‚ n_devices β”‚ Count β”‚
173+
β”‚ comments β”‚ Textual β”‚
174+
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
175+
176+
encoder = ContinuousEncoder(drop_last=true)
177+
mach = fit!(machine(encoder, X))
178+
W = transform(mach, X)
179+
180+
julia> schema(W)
181+
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
182+
β”‚ names β”‚ scitypes β”‚
183+
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
184+
β”‚ name__Danesh β”‚ Continuous β”‚
185+
β”‚ name__John β”‚ Continuous β”‚
186+
β”‚ name__Lee β”‚ Continuous β”‚
187+
β”‚ grade β”‚ Continuous β”‚
188+
β”‚ height β”‚ Continuous β”‚
189+
β”‚ n_devices β”‚ Continuous β”‚
190+
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
191+
192+
julia> setdiff(schema(X).names, report(mach).features_to_keep) # dropped features
193+
1-element Vector{Symbol}:
194+
:comments
195+
196+
```
197+
198+
See also [`OneHotEncoder`](@ref)
199+
"""
200+
ContinuousEncoder

0 commit comments

Comments
Β (0)