Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 55 additions & 19 deletions keras_rs/src/layers/embedding/base_distributed_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import importlib.util
import typing
import warnings
from typing import Any, Sequence

import keras
Expand Down Expand Up @@ -552,17 +553,17 @@ def _init_feature_configs_structures(
] = {}

# Lazily initialized.
has_sparsecore = None
has_sparsecores = None

for path, feature_config in paths_and_feature_configs:
if isinstance(feature_config, FeatureConfig):
placement = feature_config.table.placement
# Resolve "auto" to an actual placement.
if placement == "auto":
if has_sparsecore is None:
has_sparsecore = self._has_sparsecore()
if has_sparsecores is None:
has_sparsecores = self.has_sparsecores()
placement = (
"sparsecore" if has_sparsecore else "default_device"
"sparsecore" if has_sparsecores else "default_device"
)
else:
# It's a `tf.tpu.experimental.embedding.FeatureConfig`.
Expand Down Expand Up @@ -936,24 +937,53 @@ def _default_device_get_embedding_tables(self) -> dict[str, types.Tensor]:
)
return tables

def _has_sparsecore(self) -> bool:
@classmethod
def has_sparsecores(cls) -> bool:
"""Return whether the current devices are TPUs with SparseCore chips.

This is a class method and can be invoked before instantiating a
`DistributedEmbedding`.

Returns:
True if devices are TPUs with SparseCore chips.

Example:

```python
if keras_rs.layers.DistributedEmbedding.has_sparsecores():
print("We have SparseCores")
```
"""
# Explicitly check for SparseCore availability.
# We need this check here rather than in jax/distributed_embedding.py
# so that we can warn the user about missing dependencies.
if keras.backend.backend() == "jax":
# Check if SparseCores are available.
try:
import jax
import jax

tpu_devices = jax.devices("tpu")
except RuntimeError:
if jax.default_backend() != "tpu":
# No TPUs available.
return False

if len(tpu_devices) > 0:
device_kind = tpu_devices[0].device_kind
if device_kind in ["TPU v5", "TPU v6 lite"]:
return True
if importlib.util.find_spec("jax_tpu_embedding") is None:
# jax-tpu-embedding is missing and we're on TPU, warn.
warnings.warn(
"Using DistributedEmbedding on TPU without the "
"jax-tpu-embedding module installed. The SparseCores will "
"not be detected and a placement of `auto` will place the "
"tables on TensorCore, which can affect performance. "
"Install the module via `pip install jax-tpu-embedding`"
)
return False

# Rely on jax_tpu_embedding's `num_sparsecores_per_device`.
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils

try:
return jte_utils.num_sparsecores_per_device() > 0 # type: ignore[no-any-return]
except ValueError:
# `num_sparsecores_per_device` raises if there is no SparseCore.
return False

return False

Expand All @@ -965,13 +995,19 @@ def _sparsecore_init(
del feature_configs, table_stacking

if keras.backend.backend() == "jax":
jax_tpu_embedding_spec = importlib.util.find_spec(
"jax_tpu_embedding"
)
if jax_tpu_embedding_spec is None:
import jax

if jax.default_backend() != "tpu":
raise ValueError(
"The `sparsecore` placement is not available on "
f"{jax.default_backend()}"
)

if importlib.util.find_spec("jax_tpu_embedding") is None:
raise ImportError(
"Please install jax-tpu-embedding to use "
"DistributedEmbedding on sparsecore devices."
"DistributedEmbedding on TPU with SparseCore chips "
"requires the jax-tpu-embedding module. You can install it "
"via `pip install jax-tpu-embedding`"
)

raise self._unsupported_placement_error("sparsecore")
Expand Down
15 changes: 12 additions & 3 deletions keras_rs/src/layers/embedding/jax/distributed_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
if jax.__version_info__ >= (0, 8, 0):
from jax import shard_map
else:
from jax.experimental.shard_map import shard_map # type: ignore[assignment]
from jax.experimental.shard_map import ( # type: ignore[assignment, no-redef]
shard_map,
)


ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
Expand Down Expand Up @@ -348,9 +350,16 @@ def _sparsecore_init(
feature_configs: dict[str, FeatureConfig],
table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
) -> None:
if not self._has_sparsecore():
if not self.has_sparsecores():
if jax.default_backend() != "tpu":
# If jax-tpu-embedding is installed, but running on GPU or CPU.
raise ValueError(
"The `sparsecore` placement is not available on "
f"{jax.default_backend()}"
)

raise ValueError(
"Not sparse cores available, cannot use explicit sparsecore"
"No SparseCores available, cannot use explicit `sparsecore`"
" placement."
)

Expand Down
26 changes: 9 additions & 17 deletions keras_rs/src/layers/embedding/jax/distributed_embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,9 @@ def _create_sparsecore_layout(
return sparsecore_layout


def _num_sparsecores_per_device() -> int:
    """Return the number of SparseCores per device used by these tests.

    Returns:
        The value reported by `jte_utils.num_sparsecores_per_device()` when
        `test_utils.has_sparsecores()` is true, otherwise `1`.
    """
    if not test_utils.has_sparsecores():
        # Non-sparsecore tests fall back to a single shard per device.
        return 1

    return jte_utils.num_sparsecores_per_device()


@pytest.mark.skipif(
keras.backend.backend() != "jax",
reason="Backend specific test",
not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(),
reason="Requires SparseCores",
)
class ShardedInitializerTest(parameterized.TestCase):
@parameterized.product(
Expand Down Expand Up @@ -103,8 +95,8 @@ def test_wrap_and_call(


@pytest.mark.skipif(
keras.backend.backend() != "jax",
reason="Backend specific test",
not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(),
reason="Requires SparseCores",
)
class StackedTableInitializerTest(parameterized.TestCase):
def test_sharded_matches_unsharded(self):
Expand All @@ -130,7 +122,7 @@ def test_sharded_matches_unsharded(self):
)

device_count = jax.device_count()
num_sc_per_device = _num_sparsecores_per_device()
num_sc_per_device = jte_utils.num_sparsecores_per_device()
num_table_shards = device_count * num_sc_per_device

# Convert to JAX and stack tables.
Expand Down Expand Up @@ -208,7 +200,7 @@ def test_random_shards(self):
}

device_count = jax.device_count()
num_sc_per_device = _num_sparsecores_per_device()
num_sc_per_device = jte_utils.num_sparsecores_per_device()
num_table_shards = device_count * num_sc_per_device

table_stacking_lib.stack_tables(
Expand Down Expand Up @@ -267,7 +259,7 @@ def test_compilability(self):
}

device_count = jax.device_count()
num_sc_per_device = _num_sparsecores_per_device()
num_sc_per_device = jte_utils.num_sparsecores_per_device()
num_table_shards = device_count * num_sc_per_device

table_stacking_lib.stack_tables(
Expand Down Expand Up @@ -305,8 +297,8 @@ def my_initializer(shape: tuple[int, int], dtype: Any):


@pytest.mark.skipif(
keras.backend.backend() != "jax",
reason="Backend specific test",
not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(),
reason="Requires SparseCores",
)
class DistributedEmbeddingLayerTest(parameterized.TestCase):
@parameterized.product(
Expand Down
2 changes: 1 addition & 1 deletion keras_rs/src/layers/embedding/jax/embedding_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
else:
from jax.experimental.shard_map import shard_map as exp_shard_map

def shard_map( # type: ignore[misc]
def shard_map( # type: ignore[misc, no-redef]
f: Any = None,
/,
*,
Expand Down
9 changes: 5 additions & 4 deletions keras_rs/src/layers/embedding/jax/embedding_lookup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils

from keras_rs.src.layers.embedding.jax import distributed_embedding
from keras_rs.src.layers.embedding.jax import embedding_lookup
from keras_rs.src.layers.embedding.jax import embedding_utils
from keras_rs.src.layers.embedding.jax import test_utils
Expand Down Expand Up @@ -133,7 +134,7 @@ def _create_table_and_feature_specs(
stacked=[True, False],
)
def test_forward_pass(self, ragged: bool, stacked: bool):
if not test_utils.has_sparsecores():
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
self.skipTest("Test requires sparsecores.")

devices = jax.devices()
Expand Down Expand Up @@ -215,7 +216,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool):
def test_model_sharding(
self, ragged: bool, stacked: bool, num_model_shards: int
):
if not test_utils.has_sparsecores():
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
self.skipTest("Test requires sparsecores.")

if num_model_shards > jax.device_count():
Expand Down Expand Up @@ -319,7 +320,7 @@ def test_backward_pass(
stacked: bool,
optimizer: embedding_spec.OptimizerSpec,
):
if not test_utils.has_sparsecores():
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
self.skipTest("Test requires sparsecores.")

devices = jax.devices()
Expand Down Expand Up @@ -426,7 +427,7 @@ def test_autograd(
stacked: bool,
optimizer: embedding_spec.OptimizerSpec,
):
if not test_utils.has_sparsecores():
if not distributed_embedding.DistributedEmbedding.has_sparsecores():
self.skipTest("Test requires sparsecores.")

devices = jax.devices()
Expand Down
7 changes: 0 additions & 7 deletions keras_rs/src/layers/embedding/jax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@
Shape: TypeAlias = tuple[int, ...]


def has_sparsecores() -> bool:
    """Return whether the first JAX device is a known SparseCore TPU kind.

    Only the `device_kind` of `jax.devices()[0]` is inspected, and it is
    compared against a fixed list of device kind names.

    Returns:
        True if the first device's kind is "TPU v5" or "TPU v6 lite",
        False otherwise.
    """
    first_device = jax.devices()[0]
    return first_device.device_kind in ("TPU v5", "TPU v6 lite")


def _round_up_to_multiple(value: int, multiple: int) -> int:
return ((value + multiple - 1) // multiple) * multiple

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,17 @@ def __init__(
feature_configs, table_stacking=table_stacking, **kwargs
)

def _is_tpu_strategy(self, strategy: tf.distribute.Strategy) -> bool:
@classmethod
def _is_tpu_strategy(cls, strategy: tf.distribute.Strategy) -> bool:
return isinstance(
strategy,
(tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
)

def _has_sparsecore(self) -> bool:
@classmethod
def has_sparsecores(cls) -> bool:
strategy = tf.distribute.get_strategy()
if self._is_tpu_strategy(strategy):
if cls._is_tpu_strategy(strategy):
tpu_embedding_feature = (
strategy.extended.tpu_hardware_feature.embedding_feature
)
Expand Down
Loading