Skip to content

Commit a7147e0

Browse files
committed
Add support for DistributedEmbedding for Ironwood and expose has_sparsecores.
- Ironwood is identified as `TPU7x`. - Also added `TPU v5p`. - Added public class method `DistributedEmbedding.has_sparsecores`. - Removed duplicate implementation of `has_sparsecores` in `test_utils`.
1 parent f9be2ec commit a7147e0

File tree

6 files changed

+37
-25
lines changed

6 files changed

+37
-25
lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -552,17 +552,17 @@ def _init_feature_configs_structures(
552552
] = {}
553553

554554
# Lazily initialized.
555-
has_sparsecore = None
555+
has_sparsecores = None
556556

557557
for path, feature_config in paths_and_feature_configs:
558558
if isinstance(feature_config, FeatureConfig):
559559
placement = feature_config.table.placement
560560
# Resolve "auto" to an actual placement.
561561
if placement == "auto":
562-
if has_sparsecore is None:
563-
has_sparsecore = self._has_sparsecore()
562+
if has_sparsecores is None:
563+
has_sparsecores = self.has_sparsecores()
564564
placement = (
565-
"sparsecore" if has_sparsecore else "default_device"
565+
"sparsecore" if has_sparsecores else "default_device"
566566
)
567567
else:
568568
# It's a `tf.tpu.experimental.embedding.FeatureConfig`.
@@ -936,7 +936,23 @@ def _default_device_get_embedding_tables(self) -> dict[str, types.Tensor]:
936936
)
937937
return tables
938938

939-
def _has_sparsecore(self) -> bool:
939+
@classmethod
940+
def has_sparsecores(cls) -> bool:
941+
"""Return whether the current devices are TPUs with SparseCore chips.
942+
943+
This is a class method and can be invoked before instantiating a
944+
`DistributedEmbedding`.
945+
946+
Returns:
947+
True if devices are TPUs with SparseCore chips.
948+
949+
Example:
950+
951+
```python
952+
if keras_rs.layers.DistributedEmbedding.has_sparsecores():
953+
print("We have SparseCores")
954+
```
955+
"""
940956
# Explicitly check for SparseCore availability.
941957
# We need this check here rather than in jax/distributed_embedding.py
942958
# so that we can warn the user about missing dependencies.
@@ -952,7 +968,7 @@ def _has_sparsecore(self) -> bool:
952968

953969
if len(tpu_devices) > 0:
954970
device_kind = tpu_devices[0].device_kind
955-
if device_kind in ["TPU v5", "TPU v6 lite"]:
971+
if device_kind in ["TPU v5", "TPU v5p", "TPU v6 lite", "TPU7x"]:
956972
return True
957973

958974
return False

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def _sparsecore_init(
348348
feature_configs: dict[str, FeatureConfig],
349349
table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
350350
) -> None:
351-
if not self._has_sparsecore():
351+
if not self.has_sparsecores():
352352
raise ValueError(
353353
"Not sparse cores available, cannot use explicit sparsecore"
354354
" placement."

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ def _create_sparsecore_layout(
6161

6262

6363
def _num_sparsecores_per_device() -> int:
64-
if test_utils.has_sparsecores():
64+
try:
6565
return jte_utils.num_sparsecores_per_device()
66-
67-
# Default to one for non-sparsecore tests.
68-
return 1
66+
except ValueError:
67+
# Default to one for non-sparsecore tests.
68+
return 1
6969

7070

7171
@pytest.mark.skipif(

keras_rs/src/layers/embedding/jax/embedding_lookup_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
1515
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
1616

17+
from keras_rs.src.layers.embedding.jax import distributed_embedding
1718
from keras_rs.src.layers.embedding.jax import embedding_lookup
1819
from keras_rs.src.layers.embedding.jax import embedding_utils
1920
from keras_rs.src.layers.embedding.jax import test_utils
@@ -133,7 +134,7 @@ def _create_table_and_feature_specs(
133134
stacked=[True, False],
134135
)
135136
def test_forward_pass(self, ragged: bool, stacked: bool):
136-
if not test_utils.has_sparsecores():
137+
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
137138
self.skipTest("Test requires sparsecores.")
138139

139140
devices = jax.devices()
@@ -215,7 +216,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool):
215216
def test_model_sharding(
216217
self, ragged: bool, stacked: bool, num_model_shards: int
217218
):
218-
if not test_utils.has_sparsecores():
219+
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
219220
self.skipTest("Test requires sparsecores.")
220221

221222
if num_model_shards > jax.device_count():
@@ -319,7 +320,7 @@ def test_backward_pass(
319320
stacked: bool,
320321
optimizer: embedding_spec.OptimizerSpec,
321322
):
322-
if not test_utils.has_sparsecores():
323+
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
323324
self.skipTest("Test requires sparsecores.")
324325

325326
devices = jax.devices()
@@ -426,7 +427,7 @@ def test_autograd(
426427
stacked: bool,
427428
optimizer: embedding_spec.OptimizerSpec,
428429
):
429-
if not test_utils.has_sparsecores():
430+
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
430431
self.skipTest("Test requires sparsecores.")
431432

432433
devices = jax.devices()

keras_rs/src/layers/embedding/jax/test_utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
Shape: TypeAlias = tuple[int, ...]
2121

2222

23-
def has_sparsecores() -> bool:
24-
device_kind = jax.devices()[0].device_kind
25-
if device_kind in ["TPU v5", "TPU v6 lite"]:
26-
return True
27-
return False
28-
29-
3023
def _round_up_to_multiple(value: int, multiple: int) -> int:
3124
return ((value + multiple - 1) // multiple) * multiple
3225

keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,17 @@ def __init__(
6161
feature_configs, table_stacking=table_stacking, **kwargs
6262
)
6363

64-
def _is_tpu_strategy(self, strategy: tf.distribute.Strategy) -> bool:
64+
@classmethod
65+
def _is_tpu_strategy(cls, strategy: tf.distribute.Strategy) -> bool:
6566
return isinstance(
6667
strategy,
6768
(tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
6869
)
6970

70-
def _has_sparsecore(self) -> bool:
71+
@classmethod
72+
def has_sparsecores(cls) -> bool:
7173
strategy = tf.distribute.get_strategy()
72-
if self._is_tpu_strategy(strategy):
74+
if cls._is_tpu_strategy(strategy):
7375
tpu_embedding_feature = (
7476
strategy.extended.tpu_hardware_feature.embedding_feature
7577
)

0 commit comments

Comments (0)