
Commit f7ac80e
✨ improve entity embeddings tutorial
1 parent 09de647

2 files changed: +29 −30 lines

docs/src/tutorials/entity_embeddings/notebook.jl (5 additions, 5 deletions)

@@ -28,6 +28,7 @@ Pkg.instantiate(); #src
 
 ## Import all required packages
 using MLJ
+using MLJFlux
 using CategoricalArrays
 using DataFrames
 using Optimisers
@@ -160,8 +161,8 @@ df = coerce(df,
     Symbol("Content Rating") => Multiclass,
     :Genres => Multiclass,
     Symbol("Android Ver") => Multiclass,
-    :Rating => Continuous, ## Keep original for reference
-    :RatingCategory => Multiclass, ## New categorical target
+    :Rating => Continuous, ## Keep original for reference
+    :RatingCategory => OrderedFactor, ## New categorical target
 );
 schema(df)
 
@@ -183,7 +184,6 @@ X = select(df, Not([:Rating, :RatingCategory])); ## Exclude both rating columns
     rng = Random.Xoshiro(41),
 );
 
-using MLJFlux
 
 # ## Building the EntityEmbedder Model
 
@@ -233,8 +233,8 @@ MLJ.fit!(mach, force = true, verbosity = 1);
 # After training, we can use the embedder as a transformer to convert categorical features into their learned embedding representations.
 
 ## Transform the data using the learned embeddings
-X_train_embedded = MLJFlux.transform(mach, X_train)
-X_test_embedded = MLJFlux.transform(mach, X_test);
+X_train_embedded = MLJ.transform(mach, X_train)
+X_test_embedded = MLJ.transform(mach, X_test);
 
 ## Check the schema transformation
 println("Original schema:")

docs/src/tutorials/entity_embeddings/notebook.md (24 additions, 25 deletions)

@@ -32,6 +32,7 @@ Pkg.activate(@__DIR__);
 
 # Import all required packages
 using MLJ
+using MLJFlux
 using CategoricalArrays
 using DataFrames
 using Optimisers
@@ -184,7 +185,7 @@ println("\nUnique rating categories: $(sort(unique(df.RatingCategory)))")
 
 ````
 Distribution of categorical rating labels:
-OrderedCollections.OrderedDict{CategoricalArrays.CategoricalValue{String, UInt32}, Int64}("1.0" => 17, "1.5" => 18, "2.0" => 53, "2.5" => 105, "3.0" => 281, "3.5" => 722, "4.0" => 2420, "4.5" => 3542, "5.0" => 571, "NaN" => 1416)
+OrderedCollections.OrderedDict{CategoricalValue{String, UInt32}, Int64}("1.0" => 17, "1.5" => 18, "2.0" => 53, "2.5" => 105, "3.0" => 281, "3.5" => 722, "4.0" => 2420, "4.5" => 3542, "5.0" => 571, "NaN" => 1416)
 
 Unique rating categories: ["1.0", "1.5", "2.0", "2.5", "3.0", "3.5", "4.0", "4.5", "5.0", "NaN"]
 
@@ -207,28 +208,28 @@ df = coerce(df,
     Symbol("Content Rating") => Multiclass,
     :Genres => Multiclass,
     Symbol("Android Ver") => Multiclass,
-    :Rating => Continuous, ## Keep original for reference
-    :RatingCategory => Multiclass, ## New categorical target
+    :Rating => Continuous, ## Keep original for reference
+    :RatingCategory => OrderedFactor, ## New categorical target
 );
 schema(df)
 ````
 
 ````
-┌────────────────┬────────────────┬────────────────────────────────────┐
-│ names          │ scitypes       │ types                              │
-├────────────────┼────────────────┼────────────────────────────────────┤
-│ Category       │ Multiclass{33} │ CategoricalValue{String31, UInt32} │
-│ Reviews        │ Continuous     │ Float64                            │
-│ Size           │ Continuous     │ Float64                            │
-│ Installs       │ Continuous     │ Float64                            │
-│ Type           │ Multiclass{2}  │ CategoricalValue{String7, UInt32}  │
-│ Price          │ Continuous     │ Float64                            │
-│ Content Rating │ Multiclass{6}  │ CategoricalValue{String15, UInt32} │
-│ Genres         │ Multiclass{48} │ CategoricalValue{String, UInt32}   │
-│ Android Ver    │ Multiclass{34} │ CategoricalValue{String31, UInt32} │
-│ Rating         │ Continuous     │ Float64                            │
-│ RatingCategory │ Multiclass{10} │ CategoricalValue{String, UInt32}   │
-└────────────────┴────────────────┴────────────────────────────────────┘
+┌────────────────┬───────────────────┬────────────────────────────────────┐
+│ names          │ scitypes          │ types                              │
+├────────────────┼───────────────────┼────────────────────────────────────┤
+│ Category       │ Multiclass{33}    │ CategoricalValue{String31, UInt32} │
+│ Reviews        │ Continuous        │ Float64                            │
+│ Size           │ Continuous        │ Float64                            │
+│ Installs       │ Continuous        │ Float64                            │
+│ Type           │ Multiclass{2}     │ CategoricalValue{String7, UInt32}  │
+│ Price          │ Continuous        │ Float64                            │
+│ Content Rating │ Multiclass{6}     │ CategoricalValue{String15, UInt32} │
+│ Genres         │ Multiclass{48}    │ CategoricalValue{String, UInt32}   │
+│ Android Ver    │ Multiclass{34}    │ CategoricalValue{String31, UInt32} │
+│ Rating         │ Continuous        │ Float64                            │
+│ RatingCategory │ OrderedFactor{10} │ CategoricalValue{String, UInt32}   │
+└────────────────┴───────────────────┴────────────────────────────────────┘
 
 ````
 
@@ -250,8 +251,6 @@ X = select(df, Not([:Rating, :RatingCategory])); ## Exclude both rating columns
     stratify = y,
     rng = Random.Xoshiro(41),
 );
-
-using MLJFlux
 ````
 
 ## Building the EntityEmbedder Model
@@ -321,7 +320,7 @@ EntityEmbedder(
 alpha = 0.0,
 rng = 39,
 optimiser_changes_trigger_retraining = false,
-acceleration = ComputationalResources.CUDALibs{Nothing}(nothing),
+acceleration = CUDALibs{Nothing}(nothing),
 embedding_dims = Dict{Symbol, Real}(:Category => 2, Symbol("Content Rating") => 2, Symbol("Android Ver") => 2, :Genres => 2, :Type => 2)))
 ````
 
@@ -351,8 +350,8 @@ After training, we can use the embedder as a transformer to convert categorical
 
 ````julia
 # Transform the data using the learned embeddings
-X_train_embedded = MLJFlux.transform(mach, X_train)
-X_test_embedded = MLJFlux.transform(mach, X_test);
+X_train_embedded = MLJ.transform(mach, X_train)
+X_test_embedded = MLJ.transform(mach, X_test);
 
 # Check the schema transformation
 println("Original schema:")
@@ -389,8 +388,8 @@ MLJ.fit!(pipe_mach, verbosity = 0)
 trained Machine; does not cache data
 model: ProbabilisticPipeline(entity_embedder = EntityEmbedder(model = NeuralNetworkClassifier(builder = Short(n_hidden = 14, …), …)), …)
 args:
-1: Source @225 ⏎ ScientificTypesBase.Table{Union{AbstractVector{ScientificTypesBase.Continuous}, AbstractVector{ScientificTypesBase.Multiclass{33}}, AbstractVector{ScientificTypesBase.Multiclass{2}}, AbstractVector{ScientificTypesBase.Multiclass{6}}, AbstractVector{ScientificTypesBase.Multiclass{48}}, AbstractVector{ScientificTypesBase.Multiclass{34}}}}
-2: Source @148 ⏎ AbstractVector{ScientificTypesBase.Multiclass{10}}
+1: Source @927 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{33}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{6}}, AbstractVector{Multiclass{48}}, AbstractVector{Multiclass{34}}}}
+2: Source @044 ⏎ AbstractVector{OrderedFactor{10}}
 
 ````
 