Skip to content

Commit fd56e87

Browse files
committed
Add support for DistributedEmbedding for Ironwood and expose has_sparsecores.

- Use `num_sparsecores_per_device` from `jax-tpu-embedding` instead of having a duplicated hardcoded list of supported TPUs.
- Added public class method `DistributedEmbedding.has_sparsecores`.
- Added warning when running JAX on TPU with `jax-tpu-embedding` not installed.
- Made error messages more specific and consistent with Keras errors.
1 parent f9be2ec commit fd56e87

File tree

6 files changed

+83
-52
lines changed

6 files changed

+83
-52
lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import dataclasses
33
import importlib.util
44
import typing
5+
import warnings
56
from typing import Any, Sequence
67

78
import keras
@@ -552,17 +553,17 @@ def _init_feature_configs_structures(
552553
] = {}
553554

554555
# Lazily initialized.
555-
has_sparsecore = None
556+
has_sparsecores = None
556557

557558
for path, feature_config in paths_and_feature_configs:
558559
if isinstance(feature_config, FeatureConfig):
559560
placement = feature_config.table.placement
560561
# Resolve "auto" to an actual placement.
561562
if placement == "auto":
562-
if has_sparsecore is None:
563-
has_sparsecore = self._has_sparsecore()
563+
if has_sparsecores is None:
564+
has_sparsecores = self.has_sparsecores()
564565
placement = (
565-
"sparsecore" if has_sparsecore else "default_device"
566+
"sparsecore" if has_sparsecores else "default_device"
566567
)
567568
else:
568569
# It's a `tf.tpu.experimental.embedding.FeatureConfig`.
@@ -936,24 +937,53 @@ def _default_device_get_embedding_tables(self) -> dict[str, types.Tensor]:
936937
)
937938
return tables
938939

939-
def _has_sparsecore(self) -> bool:
940+
@classmethod
941+
def has_sparsecores(cls) -> bool:
942+
"""Return whether the current devices are TPUs with SparseCore chips.
943+
944+
This is a class method and can be invoked before instantiating a
945+
`DistributedEmbedding`.
946+
947+
Returns:
948+
True if devices are TPUs with SparseCore chips.
949+
950+
Example:
951+
952+
```python
953+
if keras_rs.layers.DistributedEmbedding.has_sparsecores():
954+
print("We have SparseCores")
955+
```
956+
"""
940957
# Explicitly check for SparseCore availability.
941958
# We need this check here rather than in jax/distributed_embedding.py
942959
# so that we can warn the user about missing dependencies.
943960
if keras.backend.backend() == "jax":
944961
# Check if SparseCores are available.
945-
try:
946-
import jax
962+
import jax
947963

948-
tpu_devices = jax.devices("tpu")
949-
except RuntimeError:
964+
if jax.default_backend() != "tpu":
950965
# No TPUs available.
951966
return False
952967

953-
if len(tpu_devices) > 0:
954-
device_kind = tpu_devices[0].device_kind
955-
if device_kind in ["TPU v5", "TPU v6 lite"]:
956-
return True
968+
if importlib.util.find_spec("jax_tpu_embedding") is None:
969+
# jax-tpu-embedding is missing and we're on TPU, warn.
970+
warnings.warn(
971+
"Using DistributedEmbedding on TPU without the "
972+
"jax-tpu-embedding module installed. The SparseCores will "
973+
"not be detected and a placement of `auto` will place the "
974+
"tables on TensorCore, which can affect performance. "
975+
"Install the module via `pip install jax-tpu-embedding`"
976+
)
977+
return False
978+
979+
# Rely on jax_tpu_embedding's `num_sparsecores_per_device`.
980+
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
981+
982+
try:
983+
return jte_utils.num_sparsecores_per_device() > 0 # type: ignore[no-any-return]
984+
except ValueError:
985+
# `num_sparsecores_per_device` raises if there is no SparseCore.
986+
return False
957987

958988
return False
959989

@@ -965,13 +995,19 @@ def _sparsecore_init(
965995
del feature_configs, table_stacking
966996

967997
if keras.backend.backend() == "jax":
968-
jax_tpu_embedding_spec = importlib.util.find_spec(
969-
"jax_tpu_embedding"
970-
)
971-
if jax_tpu_embedding_spec is None:
998+
import jax
999+
1000+
if jax.default_backend() != "tpu":
1001+
raise ValueError(
1002+
"The `sparsecore` placement is not available on "
1003+
f"{jax.default_backend()}"
1004+
)
1005+
1006+
if importlib.util.find_spec("jax_tpu_embedding") is None:
9721007
raise ImportError(
973-
"Please install jax-tpu-embedding to use "
974-
"DistributedEmbedding on sparsecore devices."
1008+
"DistributedEmbedding on TPU with SparseCore chips "
1009+
"requires the jax-tpu-embedding module. You can install it "
1010+
"via `pip install jax-tpu-embedding`"
9751011
)
9761012

9771013
raise self._unsupported_placement_error("sparsecore")

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,16 @@ def _sparsecore_init(
348348
feature_configs: dict[str, FeatureConfig],
349349
table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
350350
) -> None:
351-
if not self._has_sparsecore():
351+
if not self.has_sparsecores():
352+
if jax.default_backend() != "tpu":
353+
# If jax-tpu-embedding is installed, but running on GPU or CPU.
354+
raise ValueError(
355+
"The `sparsecore` placement is not available on "
356+
f"{jax.default_backend()}"
357+
)
358+
352359
raise ValueError(
353-
"Not sparse cores available, cannot use explicit sparsecore"
360+
"No SparseCores available, cannot use explicit `sparsecore`"
354361
" placement."
355362
)
356363

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,9 @@ def _create_sparsecore_layout(
6060
return sparsecore_layout
6161

6262

63-
def _num_sparsecores_per_device() -> int:
64-
if test_utils.has_sparsecores():
65-
return jte_utils.num_sparsecores_per_device()
66-
67-
# Default to one for non-sparsecore tests.
68-
return 1
69-
70-
7163
@pytest.mark.skipif(
72-
keras.backend.backend() != "jax",
73-
reason="Backend specific test",
64+
not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(),
65+
reason="Requires SparseCores",
7466
)
7567
class ShardedInitializerTest(parameterized.TestCase):
7668
@parameterized.product(
@@ -103,8 +95,8 @@ def test_wrap_and_call(
10395

10496

10597
@pytest.mark.skipif(
106-
keras.backend.backend() != "jax",
107-
reason="Backend specific test",
98+
not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(),
99+
reason="Requires SparseCores",
108100
)
109101
class StackedTableInitializerTest(parameterized.TestCase):
110102
def test_sharded_matches_unsharded(self):
@@ -130,7 +122,7 @@ def test_sharded_matches_unsharded(self):
130122
)
131123

132124
device_count = jax.device_count()
133-
num_sc_per_device = _num_sparsecores_per_device()
125+
num_sc_per_device = jte_utils.num_sparsecores_per_device()
134126
num_table_shards = device_count * num_sc_per_device
135127

136128
# Convert to JAX and stack tables.
@@ -208,7 +200,7 @@ def test_random_shards(self):
208200
}
209201

210202
device_count = jax.device_count()
211-
num_sc_per_device = _num_sparsecores_per_device()
203+
num_sc_per_device = jte_utils.num_sparsecores_per_device()
212204
num_table_shards = device_count * num_sc_per_device
213205

214206
table_stacking_lib.stack_tables(
@@ -267,7 +259,7 @@ def test_compilability(self):
267259
}
268260

269261
device_count = jax.device_count()
270-
num_sc_per_device = _num_sparsecores_per_device()
262+
num_sc_per_device = jte_utils.num_sparsecores_per_device()
271263
num_table_shards = device_count * num_sc_per_device
272264

273265
table_stacking_lib.stack_tables(
@@ -305,8 +297,8 @@ def my_initializer(shape: tuple[int, int], dtype: Any):
305297

306298

307299
@pytest.mark.skipif(
308-
keras.backend.backend() != "jax",
309-
reason="Backend specific test",
300+
not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(),
301+
reason="Requires SparseCores",
310302
)
311303
class DistributedEmbeddingLayerTest(parameterized.TestCase):
312304
@parameterized.product(

keras_rs/src/layers/embedding/jax/embedding_lookup_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
1515
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
1616

17+
from keras_rs.src.layers.embedding.jax import distributed_embedding
1718
from keras_rs.src.layers.embedding.jax import embedding_lookup
1819
from keras_rs.src.layers.embedding.jax import embedding_utils
1920
from keras_rs.src.layers.embedding.jax import test_utils
@@ -133,7 +134,7 @@ def _create_table_and_feature_specs(
133134
stacked=[True, False],
134135
)
135136
def test_forward_pass(self, ragged: bool, stacked: bool):
136-
if not test_utils.has_sparsecores():
137+
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
137138
self.skipTest("Test requires sparsecores.")
138139

139140
devices = jax.devices()
@@ -215,7 +216,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool):
215216
def test_model_sharding(
216217
self, ragged: bool, stacked: bool, num_model_shards: int
217218
):
218-
if not test_utils.has_sparsecores():
219+
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
219220
self.skipTest("Test requires sparsecores.")
220221

221222
if num_model_shards > jax.device_count():
@@ -319,7 +320,7 @@ def test_backward_pass(
319320
stacked: bool,
320321
optimizer: embedding_spec.OptimizerSpec,
321322
):
322-
if not test_utils.has_sparsecores():
323+
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
323324
self.skipTest("Test requires sparsecores.")
324325

325326
devices = jax.devices()
@@ -426,7 +427,7 @@ def test_autograd(
426427
stacked: bool,
427428
optimizer: embedding_spec.OptimizerSpec,
428429
):
429-
if not test_utils.has_sparsecores():
430+
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
430431
self.skipTest("Test requires sparsecores.")
431432

432433
devices = jax.devices()

keras_rs/src/layers/embedding/jax/test_utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
Shape: TypeAlias = tuple[int, ...]
2121

2222

23-
def has_sparsecores() -> bool:
24-
device_kind = jax.devices()[0].device_kind
25-
if device_kind in ["TPU v5", "TPU v6 lite"]:
26-
return True
27-
return False
28-
29-
3023
def _round_up_to_multiple(value: int, multiple: int) -> int:
3124
return ((value + multiple - 1) // multiple) * multiple
3225

keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,17 @@ def __init__(
6161
feature_configs, table_stacking=table_stacking, **kwargs
6262
)
6363

64-
def _is_tpu_strategy(self, strategy: tf.distribute.Strategy) -> bool:
64+
@classmethod
65+
def _is_tpu_strategy(cls, strategy: tf.distribute.Strategy) -> bool:
6566
return isinstance(
6667
strategy,
6768
(tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
6869
)
6970

70-
def _has_sparsecore(self) -> bool:
71+
@classmethod
72+
def has_sparsecores(cls) -> bool:
7173
strategy = tf.distribute.get_strategy()
72-
if self._is_tpu_strategy(strategy):
74+
if cls._is_tpu_strategy(strategy):
7375
tpu_embedding_feature = (
7476
strategy.extended.tpu_hardware_feature.embedding_feature
7577
)

0 commit comments

Comments (0)