module EntityEmbeddingsExt

using MLJFlux
using Tables
using ScientificTypes
using MLJModelInterface
using TableOperations
using Optimisers
using Flux
using MLJBase
using MLJTransforms
using MLJTransforms: EntityEmbedder
const MMI = MLJModelInterface

# Map an activation symbol (e.g. :relu) to the corresponding Flux activation function
function MLJTransforms.get_activation(func_symbol::Symbol)
    if hasproperty(Flux, func_symbol)
        return getproperty(Flux, func_symbol)
    else
        error("Function $func_symbol not found in Flux.")
    end
end

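"""
    entity_embedder_fit(X, y, features; kwargs...)

Fit an entity-embedding neural network on the selected features of `X` against the
target `y`, using an MLJFlux `NeuralNetworkClassifier` or `NeuralNetworkRegressor`
chosen from the scitype of `y`. Return a cache holding the learned embedding
matrices, the ordinal mappings, the detected task and the fitted machine.
"""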
function MLJTransforms.entity_embedder_fit(
    X,
    y,
    features::AbstractVector{Symbol} = Symbol[];
    ignore::Bool = true,
    hidden_layer_sizes::Tuple{Vararg{Int}} = (5,),
    activation::Symbol = :relu,
    epochs = 100,
    batch_size = 32,
    learning_rate = 0.01,
    embedding_dims::Dict{Symbol, Real} = Dict{Symbol, Real}(),
    verbosity::Int = 0,
    kwargs...,
)

    # Figure out the task from the target's element scitype
    y_scitype = elscitype(y)
    classification_types = (y_scitype <: Multiclass || y_scitype <: OrderedFactor)
    regression_types = (y_scitype <: Continuous || y_scitype <: Count)
    task =
        regression_types ? :Regression :
        classification_types ? :Classification : :Unknown
    task == :Unknown && error(
        "Your target must be Continuous/Count for regression or Multiclass/OrderedFactor for classification",
    )

    # Handle `ignore` and the given feature names
    feat_names_org = Tables.schema(X).names
    feat_names =
        (ignore) ? setdiff(feat_names_org, features) : intersect(feat_names_org, features)

    # Indices of the categorical (Finite) features among the selected features
    feat_inds_cat = [
        findfirst(feat_names .== feat_name) for
        feat_name in feat_names if elscitype(Tables.getcolumn(X, feat_name)) <: Finite
    ]

    # Select only the relevant columns in `X` based on `feat_names`
    X = X |> TableOperations.select(feat_names...) |> Tables.columntable

    # Setup builder
    builder = MLJFlux.MLP(;
        hidden = hidden_layer_sizes,
        σ = MLJTransforms.get_activation(activation),
    )

    # Fit a NeuralNetworkClassifier or NeuralNetworkRegressor accordingly
    clf =
        (task == :Classification) ?
        MLJFlux.NeuralNetworkClassifier(
            builder = builder,
            optimiser = Optimisers.Adam(learning_rate),
            batch_size = batch_size,
            epochs = epochs,
            embedding_dims = embedding_dims;
            kwargs...,
        ) :
        MLJFlux.NeuralNetworkRegressor(
            builder = builder,
            optimiser = Optimisers.Adam(learning_rate),
            batch_size = batch_size,
            epochs = epochs,
            embedding_dims = embedding_dims;
            kwargs...,
        )

    # Fit the model
    mach = machine(clf, X, y)
    fit!(mach, verbosity = verbosity)

    # Get mappings
    mapping_matrices = MLJFlux.get_embedding_matrices(
        fitted_params(mach).chain,
        feat_inds_cat,
        feat_names,
    )
    ordinal_mappings = mach.fitresult[3]
    cache = (
        mapping_matrices = mapping_matrices,
        ordinal_mappings = ordinal_mappings,
        task = task,
        machine = mach,
    )
    return cache
end

"""
Given `X` and a dict of `mapping_matrices` that maps each categorical column to a matrix,
use the columns of that matrix to replace each level of the corresponding categorical
column with its embedding vector.

This is used with the embedding matrices of the entity embedding layer in
entity-enabled models to implement entity embeddings.
"""
function MLJTransforms.entity_embedder_transform(X, cache)
    mach = cache[:machine]
    Xnew = MLJFlux.transform(mach, X)
    return Xnew
end
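
# Illustrative sketch of how the two helpers compose (hypothetical data: `color`,
# `size` and `y` below are made up for the example and are not part of this module):
#
#     using CategoricalArrays
#     X = (
#         color = categorical(["red", "green", "blue", "red"]),
#         size  = categorical(["S", "M", "L", "M"]; ordered = true),
#     )
#     y = categorical(["a", "b", "a", "b"])
#     cache = MLJTransforms.entity_embedder_fit(X, y; epochs = 10, batch_size = 2)
#     Xout  = MLJTransforms.entity_embedder_transform(X, cache)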

include("EntityEmbeddingsInterface.jl")

end