Skip to content

Commit 9381b01

Browse files
authored
Merge branch 'main' into Contrast-Encoding
2 parents 7cc4c66 + aaa9b49 commit 9381b01

20 files changed

+2653
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
3333
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
3434
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3535
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
36-
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
3736
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
37+
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
3838

3939
[targets]
4040
test = ["Test", "DataFrames", "MLJBase", "Random", "StableRNGs", "StatsModels"]

src/MLJTransforms.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using MLJModelInterface
77
using TableOperations
88
using StatsBase
99
using LinearAlgebra
10+
1011
# Other transformers
1112
using Combinatorics
1213
import Distributions
@@ -42,8 +43,25 @@ export FrequencyEncoder
4243
include("transformers/cardinality_reducer/cardinality_reducer.jl")
4344
include("transformers/cardinality_reducer/interface_mlj.jl")
4445
export cardinality_reducer_fit, cardinality_reducer_transform, CardinalityReducer
46+
47+
# Contrast encoder
4548
include("encoders/contrast_encoder/contrast_encoder.jl")
4649
include("encoders/contrast_encoder/interface_mlj.jl")
4750
export ContrastEncoder
4851

49-
end
52+
# MLJModels transformers
53+
include("transformers/other_transformers/continuous_encoder.jl")
54+
include("transformers/other_transformers/interaction_transformer.jl")
55+
include("transformers/other_transformers/univariate_time_type_to_continuous.jl")
56+
include("transformers/other_transformers/fill_imputer.jl")
57+
include("transformers/other_transformers/one_hot_encoder.jl")
58+
include("transformers/other_transformers/standardizer.jl")
59+
include("transformers/other_transformers/univariate_boxcox_transformer.jl")
60+
include("transformers/other_transformers/univariate_discretizer.jl")
61+
include("transformers/other_transformers/metadata_shared.jl")
62+
63+
export UnivariateDiscretizer,
64+
UnivariateStandardizer, Standardizer, UnivariateBoxCoxTransformer,
65+
OneHotEncoder, ContinuousEncoder, FillImputer, UnivariateFillImputer,
66+
UnivariateTimeTypeToContinuous, InteractionTransformer
67+
end
src/transformers/other_transformers/continuous_encoder.jl

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
2+
# Unsupervised transformer that arranges all convertible columns of a table
# to have `Continuous` element scitype: `Multiclass` columns are one-hot
# encoded, `Count`/`OrderedFactor` columns are coerced, and unconvertible
# columns are dropped (see `MMI.fit` below).
@with_kw_noshow mutable struct ContinuousEncoder <: Unsupervised
    drop_last::Bool = false                # drop the final class column of each one-hot encoded feature
    one_hot_ordered_factors::Bool = false  # one-hot `OrderedFactor` columns instead of coercing by order
end
6+
7+
"""
    MMI.fit(transformer::ContinuousEncoder, verbosity::Int, X)

Determine which columns of the table `X` can be converted to `Continuous`
scitype and fit an internal `OneHotEncoder` for the categorical ones.
Returns `(fitresult, cache, report)` per the MLJ model interface.
"""
function MMI.fit(transformer::ContinuousEncoder, verbosity::Int, X)

    # Identify convertible columns; everything else is dropped:
    sch = schema(X)
    Convertible = Union{Continuous, Finite, Count}
    name_scitype_pairs = collect(zip(sch.names, sch.scitypes))
    features_to_keep =
        first.(filter(p -> last(p) <: Convertible, name_scitype_pairs))
    features_to_be_dropped = setdiff(collect(sch.names), features_to_keep)

    if verbosity > 0 && !isempty(features_to_be_dropped)
        @info "Some features cannot be replaced with "*
            "`Continuous` features and will be dropped: "*
            "$features_to_be_dropped. "
    end

    # Delegate the categorical columns to a one-hot encoder:
    hot_encoder = OneHotEncoder(
        ordered_factor=transformer.one_hot_ordered_factors,
        drop_last=transformer.drop_last,
    )
    hot_fitresult, _, hot_report = MMI.fit(hot_encoder, verbosity - 1, X)

    new_features = setdiff(hot_report.new_features, features_to_be_dropped)

    fitresult = (
        features_to_keep=features_to_keep,
        one_hot_encoder=hot_encoder,
        one_hot_encoder_fitresult=hot_fitresult,
    )

    report = (
        features_to_keep=features_to_keep,
        new_features=new_features,
    )

    return fitresult, nothing, report
end
48+
49+
# The fit result already is the user-facing named tuple of fitted
# parameters, so expose it unchanged.
function MMI.fitted_params(::ContinuousEncoder, fitresult)
    return fitresult
end
50+
51+
"""
    MMI.transform(transformer::ContinuousEncoder, fitresult, X)

Transform the table `X` using the previously fitted encoder: drop columns
not selected during fitting, one-hot encode the categorical columns, and
coerce any remaining `Count`/`OrderedFactor` columns to `Continuous`.

Throws `ArgumentError` if `X` is missing any column selected at fit time.
"""
function MMI.transform(transformer::ContinuousEncoder, fitresult, X)

    # Access fields by name rather than destructuring `values(fitresult)`,
    # which silently breaks if the named tuple's field order ever changes:
    features_to_keep = fitresult.features_to_keep
    hot_encoder = fitresult.one_hot_encoder
    hot_fitresult = fitresult.one_hot_encoder_fitresult

    # dump unseen or untransformable features:
    issubset(features_to_keep, MMI.schema(X).names) ||
        throw(
            ArgumentError(
                "Supplied frame does not admit previously selected features."
            )
        )
    X0 = MMI.selectcols(X, features_to_keep)

    # one-hot encode:
    X1 = transform(hot_encoder, hot_fitresult, X0)

    # convert remaining to continuous:
    return coerce(X1, Count=>Continuous, OrderedFactor=>Continuous)
end
72+
73+
metadata_model(ContinuousEncoder,
    input_scitype = Table,
    output_scitype = Table(Continuous),
    # The load path must point at this package: the model now lives in
    # MLJTransforms, not MLJModels (where this code originated):
    load_path = "MLJTransforms.ContinuousEncoder")
77+
78+
"""
79+
$(MLJModelInterface.doc_header(ContinuousEncoder))
80+
81+
Use this model to arrange all features (features) of a table to have
82+
`Continuous` element scitype, by applying the following protocol to
83+
each feature `ftr`:
84+
85+
- If `ftr` is already `Continuous` retain it.
86+
87+
- If `ftr` is `Multiclass`, one-hot encode it.
88+
89+
- If `ftr` is `OrderedFactor`, replace it with `coerce(ftr,
90+
Continuous)` (vector of floating point integers), unless
91+
`ordered_factors=false` is specified, in which case one-hot encode
92+
it.
93+
94+
- If `ftr` is `Count`, replace it with `coerce(ftr, Continuous)`.
95+
96+
- If `ftr` has some other element scitype, or was not observed in
97+
fitting the encoder, drop it from the table.
98+
99+
**Warning:** This transformer assumes that `levels(col)` for any
100+
`Multiclass` or `OrderedFactor` column, `col`, is the same for
101+
training data and new data to be transformed.
102+
103+
To selectively one-hot-encode categorical features (without dropping
104+
features) use [`OneHotEncoder`](@ref) instead.
105+
106+
107+
# Training data
108+
109+
In MLJ or MLJBase, bind an instance `model` to data with
110+
111+
mach = machine(model, X)
112+
113+
where
114+
115+
- `X`: any Tables.jl compatible table. features can be of mixed type
116+
but only those with element scitype `Multiclass` or `OrderedFactor`
117+
can be encoded. Check column scitypes with `schema(X)`.
118+
119+
Train the machine using `fit!(mach, rows=...)`.
120+
121+
122+
# Hyper-parameters
123+
124+
- `drop_last=true`: whether to drop the column corresponding to the
125+
final class of one-hot encoded features. For example, a three-class
126+
feature is spawned into three new features if `drop_last=false`, but
127+
two just features otherwise.
128+
129+
- `one_hot_ordered_factors=false`: whether to one-hot any feature
130+
with `OrderedFactor` element scitype, or to instead coerce it
131+
directly to a (single) `Continuous` feature using the order
132+
133+
134+
# Fitted parameters
135+
136+
The fields of `fitted_params(mach)` are:
137+
138+
- `features_to_keep`: names of features that will not be dropped from
139+
the table
140+
141+
- `one_hot_encoder`: the `OneHotEncoder` model instance for handling
142+
the one-hot encoding
143+
144+
- `one_hot_encoder_fitresult`: the fitted parameters of the
145+
`OneHotEncoder` model
146+
147+
148+
# Report
149+
150+
- `features_to_keep`: names of input features that will not be dropped
151+
from the table
152+
153+
- `new_features`: names of all output features
154+
155+
156+
# Example
157+
158+
```julia
159+
X = (name=categorical(["Danesh", "Lee", "Mary", "John"]),
160+
grade=categorical(["A", "B", "A", "C"], ordered=true),
161+
height=[1.85, 1.67, 1.5, 1.67],
162+
n_devices=[3, 2, 4, 3],
163+
comments=["the force", "be", "with you", "too"])
164+
165+
julia> schema(X)
166+
┌───────────┬──────────────────┐
167+
│ names │ scitypes │
168+
├───────────┼──────────────────┤
169+
│ name │ Multiclass{4} │
170+
│ grade │ OrderedFactor{3} │
171+
│ height │ Continuous │
172+
│ n_devices │ Count │
173+
│ comments │ Textual │
174+
└───────────┴──────────────────┘
175+
176+
encoder = ContinuousEncoder(drop_last=true)
177+
mach = fit!(machine(encoder, X))
178+
W = transform(mach, X)
179+
180+
julia> schema(W)
181+
┌──────────────┬────────────┐
182+
│ names │ scitypes │
183+
├──────────────┼────────────┤
184+
│ name__Danesh │ Continuous │
185+
│ name__John │ Continuous │
186+
│ name__Lee │ Continuous │
187+
│ grade │ Continuous │
188+
│ height │ Continuous │
189+
│ n_devices │ Continuous │
190+
└──────────────┴────────────┘
191+
192+
julia> setdiff(schema(X).names, report(mach).features_to_keep) # dropped features
193+
1-element Vector{Symbol}:
194+
:comments
195+
196+
```
197+
198+
See also [`OneHotEncoder`](@ref)
199+
"""
200+
ContinuousEncoder

0 commit comments

Comments
 (0)