Skip to content

Commit 060afd9

Browse files
committed
⭐️ Add basic extension feats for entity embedders
1 parent 924df68 commit 060afd9

File tree

4 files changed

+177
-8
lines changed

4 files changed

+177
-8
lines changed

Project.toml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,21 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1919
TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87"
2020
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2121

22+
[weakdeps]
23+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
24+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
25+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
26+
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
27+
28+
[extensions]
29+
EntityEmbeddingsExt = ["MLJFlux", "Optimisers", "MLJBase", "Flux"]
30+
2231
[compat]
2332
CategoricalArrays = "0.10"
33+
Flux = "0.14.25"
34+
MLJFlux = "0.6.0"
2435
MLJModelInterface = "1.11"
36+
Optimisers = "0.3.4"
2537
ScientificTypes = "3.0"
2638
StatsBase = "0.34"
2739
TableOperations = "1.2"
@@ -30,11 +42,11 @@ julia = "1.6.7"
3042

3143
[extras]
3244
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
33-
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3445
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
3546
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
36-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
47+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3748
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
49+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3850

3951
[targets]
4052
test = ["Test", "DataFrames", "MLJBase", "Random", "StableRNGs", "StatsModels"]

ext/EntityEmbeddingsExt.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
module EntityEmbeddingsExt
2+
3+
using MLJFlux
4+
using Tables
5+
using ScientificTypes
6+
using MLJModelInterface
7+
using TableOperations
8+
using Optimisers
9+
using Flux
10+
using MLJBase
11+
using MLJTransforms
12+
using MLJTransforms: EntityEmbedder
13+
const MMI = MLJModelInterface
14+
15+
# activations
16+
function MLJTransforms.get_activation(func_symbol::Symbol)
17+
if hasproperty(Flux, func_symbol)
18+
return getproperty(Flux, func_symbol)
19+
else
20+
error("Function $func_symbol not found in Flux.")
21+
end
22+
end
23+
24+
function MLJTransforms.entity_embedder_fit(
25+
X,
26+
y,
27+
features::AbstractVector{Symbol} = Symbol[];
28+
ignore::Bool = true,
29+
hidden_layer_sizes::Tuple{Vararg{Int}} = (5,),
30+
activation::Symbol = :relu,
31+
epochs = 100,
32+
batch_size = 32,
33+
learning_rate = 0.01,
34+
embedding_dims::Dict{Symbol, Real} = Dict{Symbol, Real}(),
35+
verbosity::Int = 0,
36+
kwargs...,
37+
)
38+
39+
# Figure out task
40+
y_scitype = elscitype(y)
41+
classification_types = (y_scitype <: Multiclass || y_scitype <: OrderedFactor)
42+
regression_types = (y_scitype <: Continuous || y_scitype <: Count)
43+
task =
44+
regression_types ? :Regression :
45+
classification_types ? :Classification : :Unknown
46+
task == :Unknown && error(
47+
"Your target must be Continuous/Count for regression or Multiclass/OrderedFactor for classification",
48+
)
49+
50+
# Handle ignore and given feat names
51+
feat_names_org = Tables.schema(X).names
52+
feat_names =
53+
(ignore) ? setdiff(feat_names_org, features) : intersect(feat_names_org, features)
54+
55+
feat_inds_cat = [
56+
findfirst(feat_names .== feat_name) for
57+
feat_name in feat_names if elscitype(Tables.getcolumn(X, feat_name)) <: Finite
58+
]
59+
60+
# Select only the relevant columns in `X` based on `feat_names`
61+
X = X |> TableOperations.select(feat_names...) |> Tables.columntable
62+
63+
64+
# Setup builder
65+
builder = MLJFlux.MLP(;
66+
hidden = hidden_layer_sizes,
67+
σ = MLJTransforms.get_activation(activation),
68+
)
69+
70+
# Accordingly fit NeuralNetworkRegressor, NeuralNetworkClassifier
71+
clf =
72+
(task == :Classification) ?
73+
MLJFlux.NeuralNetworkClassifier(
74+
builder = builder,
75+
optimiser = Optimisers.Adam(learning_rate),
76+
batch_size = batch_size,
77+
epochs = epochs,
78+
embedding_dims = embedding_dims;
79+
kwargs...,
80+
) :
81+
MLJFlux.NeuralNetworkRegressor(
82+
builder = builder,
83+
optimiser = Optimisers.Adam(learning_rate),
84+
batch_size = batch_size,
85+
epochs = epochs,
86+
embedding_dims = embedding_dims;
87+
kwargs...,
88+
)
89+
90+
# Fit the model
91+
mach = machine(clf, X, y)
92+
fit!(mach, verbosity = verbosity)
93+
94+
# Get mappings
95+
96+
mapping_matrices = MLJFlux.get_embedding_matrices(
97+
fitted_params(mach).chain,
98+
feat_inds_cat,
99+
feat_names,
100+
)
101+
ordinal_mappings = mach.fitresult[3]
102+
cache = (
103+
mapping_matrices = mapping_matrices,
104+
ordinal_mappings = ordinal_mappings,
105+
task = task,
106+
machine = mach,
107+
)
108+
return cache
109+
end
110+
111+
112+
"""
113+
Given X and a dict of mapping_matrices that map each categorical column to a matrix, use the matrix to transform
114+
each level in each categorical columns using the columns of the matrix.
115+
116+
This is used with the embedding matrices of the entity embedding layer in entity enabled models to implement entity embeddings.
117+
"""
118+
function MLJTransforms.entity_embedder_transform(X, cache)
119+
mach = cache[:machine]
120+
Xnew = MLJFlux.transform(mach, X)
121+
return Xnew
122+
end
123+
124+
include("EntityEmbeddingsInterface.jl")
125+
126+
127+
end

newmeh/Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[deps]
2+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
3+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
4+
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
5+
MLJTransforms = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6"
6+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
7+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

src/MLJTransforms.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ using LinearAlgebra
1010

1111
# Other transformers
1212
using Combinatorics
13-
import Distributions
13+
using Distributions: Distributions
1414
using Parameters
1515
using Dates
1616
using OrderedCollections
1717

1818

19+
1920
const MMI = MLJModelInterface
2021

2122
# Functions of generic use across transformers
@@ -26,27 +27,27 @@ include("utils.jl")
2627
include("encoders/target_encoding/errors.jl")
2728
include("encoders/target_encoding/target_encoding.jl")
2829
include("encoders/target_encoding/interface_mlj.jl")
29-
export TargetEncoder
30+
export TargetEncoder
3031

3132
# Ordinal encoding
3233
include("encoders/ordinal_encoding/ordinal_encoding.jl")
3334
include("encoders/ordinal_encoding/interface_mlj.jl")
34-
export OrdinalEncoder
35+
export OrdinalEncoder
3536

3637
# Frequency encoding
3738
include("encoders/frequency_encoding/frequency_encoding.jl")
3839
include("encoders/frequency_encoding/interface_mlj.jl")
3940
export frequency_encoder_fit, frequency_encoder_transform, FrequencyEncoder
40-
export FrequencyEncoder
41+
export FrequencyEncoder
4142

4243
# Cardinality reduction
4344
include("transformers/cardinality_reducer/cardinality_reducer.jl")
4445
include("transformers/cardinality_reducer/interface_mlj.jl")
4546
export cardinality_reducer_fit, cardinality_reducer_transform, CardinalityReducer
46-
export CardinalityReducer
47+
export CardinalityReducer
4748
include("encoders/missingness_encoding/missingness_encoding.jl")
4849
include("encoders/missingness_encoding/interface_mlj.jl")
49-
export MissingnessEncoder
50+
export MissingnessEncoder
5051

5152
# Contrast encoder
5253
include("encoders/contrast_encoder/contrast_encoder.jl")
@@ -69,3 +70,25 @@ export UnivariateDiscretizer,
6970
OneHotEncoder, ContinuousEncoder, FillImputer, UnivariateFillImputer,
7071
UnivariateTimeTypeToContinuous, InteractionTransformer
7172
end
73+
74+
# For the extension
75+
function get_activation end
76+
function entity_embedder_fit end
77+
function entity_embedder_transform end
78+
79+
mutable struct EntityEmbedder{AS <: AbstractVector{Symbol},
80+
TV <: Tuple{Vararg{Int}},
81+
I1 <: Integer, I2 <: Integer, AF <: AbstractFloat,
82+
DSR <: Dict{Symbol, Real}, I3 <: Int} <: Unsupervised
83+
features::AS
84+
ignore::Bool
85+
hidden_layer_sizes::TV
86+
activation::Symbol
87+
epochs::I1
88+
batch_size::I2
89+
learning_rate::AF
90+
embedding_dims::DSR
91+
verbosity::I3
92+
end
93+
94+
function EntityEmbedder end

0 commit comments

Comments
 (0)