Commit 8340b57

Add DistributedEmbedding example for TPU on TensorFlow. (#2174)
This was run on a Cloud TPU v6e-1. Also tweaked some comments in the JAX DistributedEmbedding example.
1 parent 7894001 commit 8340b57

File tree

7 files changed: +1347, -15 lines changed


examples/keras_rs/distributed_embedding_jax.py

Lines changed: 5 additions & 5 deletions
@@ -1,8 +1,8 @@
 """
 Title: DistributedEmbedding using TPU SparseCore and JAX
-Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
+Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/), [C. Antonio Sánchez](https://github.com/cantonios/)
 Date created: 2025/06/03
-Last modified: 2025/06/03
+Last modified: 2025/09/02
 Description: Rank movies using a two tower model with embeddings on SparseCore.
 Accelerator: TPU
 """
@@ -56,7 +56,7 @@
 """
 ## Preparing the dataset
 
-We're going to use the same Movielens data. The ratings are the objectives we
+We're going to use the same MovieLens data. The ratings are the objectives we
 are trying to predict.
 """

@@ -150,8 +150,8 @@ def preprocess_rating(x):
 
 - A name.
 - A table, the embedding table to use.
-- An input shape (per replica).
-- An output shape (per replica).
+- An input shape (batch size is for all TPUs).
+- An output shape (batch size is for all TPUs).
 
 We can organize features in any structure we want, which can be nested. A dict
 is often a good choice to have names for the inputs and outputs.
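The two bullets changed above adjust the documented shape semantics: `input_shape` and `output_shape` now refer to the global batch across all TPUs rather than a per-replica batch. As a minimal sketch of what that means in practice (the constants and the illustrative table below are placeholders, not part of this commit; the TensorFlow example added below follows the same pattern):

import keras_rs

# Global batch size = per-replica batch size times the number of replicas
# (placeholder values, for illustration only).
PER_REPLICA_BATCH_SIZE = 256
NUM_REPLICAS = 4
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * NUM_REPLICAS
EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=2048,  # placeholder vocabulary size
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

movie_feature = keras_rs.layers.FeatureConfig(
    name="movie",
    table=movie_table,
    # Shapes use the global batch size, i.e. the batch "for all TPUs".
    input_shape=(GLOBAL_BATCH_SIZE,),
    output_shape=(GLOBAL_BATCH_SIZE, EMBEDDING_DIMENSION),
)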
Lines changed: 319 additions & 0 deletions
@@ -0,0 +1,319 @@
"""
Title: DistributedEmbedding using TPU SparseCore and TensorFlow
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2025/09/02
Last modified: 2025/09/02
Description: Rank movies using a two tower model with embeddings on SparseCore.
Accelerator: TPU
"""

"""
## Introduction

In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed
how to build a ranking model for the MovieLens dataset to suggest movies to
users.

This tutorial implements the same model trained on the same dataset but with the
use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on
TPU. This is the TensorFlow version of the tutorial. It needs to be run on TPU
v5p or v6e.

Let's begin by installing the necessary libraries. Note that we need
`tensorflow-tpu` version 2.19. We'll also install `keras-rs`.
"""

"""shell
pip install -U -q tensorflow-tpu==2.19.1
pip install -q keras-rs
"""

"""
We're using the PJRT version of the runtime for TensorFlow. We're also enabling
the MLIR bridge. This requires setting a few flags before importing TensorFlow.
"""

import os
import libtpu

os.environ["PJRT_DEVICE"] = "TPU"
os.environ["NEXT_PLUGGABLE_DEVICE_USE_C_API"] = "true"
os.environ["TF_PLUGGABLE_DEVICE_LIBRARY_PATH"] = libtpu.get_library_path()
os.environ["TF_XLA_FLAGS"] = (
    "--tf_mlir_enable_mlir_bridge=true "
    "--tf_mlir_enable_convert_control_to_data_outputs_pass=true "
    "--tf_mlir_enable_merge_control_flow_pass=true"
)

import tensorflow as tf

"""
We now set the Keras backend to TensorFlow and import the necessary libraries.
"""

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import keras_rs
import tensorflow_datasets as tfds

"""
## Creating a `TPUStrategy`

To run TensorFlow on TPU, you need to use a
[`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
to handle the distribution of the model.

The core of the model is replicated across TPU instances, which is done by the
`TPUStrategy`. Note that on GPU you would use
[`tf.distribute.MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)
instead, but that strategy does not support TPU.

Only the embedding tables handled by `DistributedEmbedding` are sharded across
the SparseCore chips of all the available TPUs.
"""

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=tpu_metadata.num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)

"""
## Dataset distribution

While the model is replicated and the embedding tables are sharded across
SparseCores, the dataset is distributed by sharding each batch across the TPUs.
We need to make sure the batch size is a multiple of the number of TPUs.
"""

PER_REPLICA_BATCH_SIZE = 256
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

"""
## Preparing the dataset

We're going to use the same MovieLens data. The ratings are the objectives we
are trying to predict.
"""

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
We need to know the number of users as we're using the user ID directly as an
index in the user embedding table.
"""

users_count = int(
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
We also need to know the number of movies as we're using the movie ID directly
as an index in the movie embedding table.
"""

movies_count = int(movies.cardinality().numpy())

"""
The inputs to the model are the user IDs and movie IDs and the labels are the
ratings.
"""


def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )


"""
We'll split the data by putting 80% of the ratings in the train set, and 20% in
the test set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
    shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
    shuffled_ratings.skip(80_000)
    .take(20_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
)

"""
## Configuring DistributedEmbedding

The `keras_rs.layers.DistributedEmbedding` handles multiple features and
multiple embedding tables. This is to enable the sharing of tables between
features and allow some optimizations that come from combining multiple
embedding lookups into a single invocation. In this section, we'll describe
how to configure these.

### Configuring tables

Tables are configured using `keras_rs.layers.TableConfig`, which has:

- A name.
- A vocabulary size (input size).
- An embedding dimension (output size).
- A combiner to specify how to reduce multiple embeddings into a single one in
  the case where we embed a sequence. Note that this doesn't apply to our
  example because we're getting a single embedding for each user and each movie.
- A placement to tell whether to put the table on the SparseCore chips or not.
  In this case, we want the `"sparsecore"` placement.
- An optimizer to specify how to apply gradients when training. Each table has
  its own optimizer and the one passed to `model.compile()` is not used for the
  embedding tables.

### Configuring features

Features are configured using `keras_rs.layers.FeatureConfig`, which has:

- A name.
- A table, the embedding table to use.
- An input shape (batch size is for all TPUs).
- An output shape (batch size is for all TPUs).

We can organize features in any structure we want, which can be nested. A dict
is often a good choice to have names for the inputs and outputs.
"""

EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}

"""
## Defining the Model

We're now ready to create a `DistributedEmbedding` inside a model. Once we have
the configuration, we simply pass it to the constructor of
`DistributedEmbedding`. Then, within the model `call` method,
`DistributedEmbedding` is the first layer we call.

The outputs have the exact same structure as the inputs. In our example, we
concatenate the embeddings we got as outputs and run them through a tower of
dense layers.
"""


class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()

        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )


"""
Let's now instantiate the model. We then use `model.compile()` to configure the
loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to
the dense layers and not the embedding tables.
"""

with strategy.scope():
    model = EmbeddingModel(FEATURE_CONFIGS)

    model.compile(
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError()],
        optimizer="adagrad",
    )

"""
## Fitting and evaluating

We can use the standard Keras `model.fit()` to train the model. Keras will
automatically use the `TPUStrategy` to distribute the model and the data.
"""

with strategy.scope():
    model.fit(train_ratings, epochs=5)

"""
Same for `model.evaluate()`.
"""

with strategy.scope():
    model.evaluate(test_ratings, return_dict=True)

"""
That's it.

This example shows that after setting up the `TPUStrategy` and configuring the
`DistributedEmbedding`, you can use the standard Keras workflows.
"""

examples/keras_rs/ipynb/distributed_embedding_jax.ipynb

Lines changed: 5 additions & 5 deletions
@@ -8,9 +8,9 @@
 "source": [
 "# DistributedEmbedding using TPU SparseCore and JAX\n",
 "\n",
-"**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)<br>\n",
+"**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/), [C. Antonio Sánchez](https://github.com/cantonios/)<br>\n",
 "**Date created:** 2025/06/03<br>\n",
-"**Last modified:** 2025/06/03<br>\n",
+"**Last modified:** 2025/09/02<br>\n",
 "**Description:** Rank movies using a two tower model with embeddings on SparseCore."
 ]
 },
@@ -103,7 +103,7 @@
 "source": [
 "## Preparing the dataset\n",
 "\n",
-"We're going to use the same Movielens data. The ratings are the objectives we\n",
+"We're going to use the same MovieLens data. The ratings are the objectives we\n",
 "are trying to predict."
 ]
 },
@@ -267,8 +267,8 @@
 "\n",
 "- A name.\n",
 "- A table, the embedding table to use.\n",
-"- An input shape (per replica).\n",
-"- An output shape (per replica).\n",
+"- An input shape (batch size is for all TPUs).\n",
+"- An output shape (batch size is for all TPUs).\n",
 "\n",
 "We can organize features in any structure we want, which can be nested. A dict\n",
 "is often a good choice to have names for the inputs and outputs."
