Skip to content

Commit 4a0cbbf

Browse files
authored
DistributedEmbedding for JAX example needs more recent JAX version. (#2179)
The preinstalled version of JAX on colab no longer works with the latest version of the `jax-tpu-embedding`.
1 parent 0af847b commit 4a0cbbf

File tree

3 files changed

+3
-0
lines changed

3 files changed

+3
-0
lines changed

examples/keras_rs/distributed_embedding_jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"""
2525

2626
"""shell
27+
pip install -q -U jax[tpu]>=0.7.0
2728
pip install -q jax-tpu-embedding
2829
pip install -q tensorflow-cpu
2930
pip install -q keras-rs

examples/keras_rs/ipynb/distributed_embedding_jax.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
},
4444
"outputs": [],
4545
"source": [
46+
"!pip install -q -U jax[tpu]>=0.7.0\n",
4647
"!pip install -q jax-tpu-embedding\n",
4748
"!pip install -q tensorflow-cpu\n",
4849
"!pip install -q keras-rs"

examples/keras_rs/md/distributed_embedding_jax.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ libraries.
2727

2828

2929
```python
30+
!pip install -q -U jax[tpu]>=0.7.0
3031
!pip install -q jax-tpu-embedding
3132
!pip install -q tensorflow-cpu
3233
!pip install -q keras-rs

0 commit comments

Comments
 (0)