diff --git a/keras_rs/src/layers/embedding/base_distributed_embedding.py b/keras_rs/src/layers/embedding/base_distributed_embedding.py index 215131bb..c2c2d628 100644 --- a/keras_rs/src/layers/embedding/base_distributed_embedding.py +++ b/keras_rs/src/layers/embedding/base_distributed_embedding.py @@ -2,6 +2,7 @@ import dataclasses import importlib.util import typing +import warnings from typing import Any, Sequence import keras @@ -552,17 +553,17 @@ def _init_feature_configs_structures( ] = {} # Lazily initialized. - has_sparsecore = None + has_sparsecores = None for path, feature_config in paths_and_feature_configs: if isinstance(feature_config, FeatureConfig): placement = feature_config.table.placement # Resolve "auto" to an actual placement. if placement == "auto": - if has_sparsecore is None: - has_sparsecore = self._has_sparsecore() + if has_sparsecores is None: + has_sparsecores = self.has_sparsecores() placement = ( - "sparsecore" if has_sparsecore else "default_device" + "sparsecore" if has_sparsecores else "default_device" ) else: # It's a `tf.tpu.experimental.embedding.FeatureConfig`. @@ -936,24 +937,53 @@ def _default_device_get_embedding_tables(self) -> dict[str, types.Tensor]: ) return tables - def _has_sparsecore(self) -> bool: + @classmethod + def has_sparsecores(cls) -> bool: + """Return whether the current devices are TPUs with SparseCore chips. + + This is a class method and can be invoked before instantiating a + `DistributedEmbedding`. + + Returns: + True if devices are TPUs with SparseCore chips. + + Example: + + ```python + if keras_rs.layers.DistributedEmbedding.has_sparsecores(): + print("We have SparseCores") + ``` + """ # Explicitly check for SparseCore availability. # We need this check here rather than in jax/distributed_embedding.py # so that we can warn the user about missing dependencies. if keras.backend.backend() == "jax": # Check if SparseCores are available. - try: - import jax + import jax - tpu_devices = jax.devices("tpu") - except RuntimeError: + if jax.default_backend() != "tpu": # No TPUs available. return False - if len(tpu_devices) > 0: - device_kind = tpu_devices[0].device_kind - if device_kind in ["TPU v5", "TPU v6 lite"]: - return True + if importlib.util.find_spec("jax_tpu_embedding") is None: + # jax-tpu-embedding is missing and we're on TPU, warn. + warnings.warn( + "Using DistributedEmbedding on TPU without the " + "jax-tpu-embedding module installed. The SparseCores will " + "not be detected and a placement of `auto` will place the " + "tables on TensorCore, which can affect performance. " + "Install the module via `pip install jax-tpu-embedding`" + ) + return False + + # Rely on jax_tpu_embedding's `num_sparsecores_per_device`. + from jax_tpu_embedding.sparsecore.utils import utils as jte_utils + + try: + return jte_utils.num_sparsecores_per_device() > 0 # type: ignore[no-any-return] + except ValueError: + # `num_sparsecores_per_device` raises if there is no SparseCore. + return False return False @@ -965,13 +995,19 @@ def _sparsecore_init( del feature_configs, table_stacking if keras.backend.backend() == "jax": - jax_tpu_embedding_spec = importlib.util.find_spec( - "jax_tpu_embedding" - ) - if jax_tpu_embedding_spec is None: + import jax + + if jax.default_backend() != "tpu": + raise ValueError( + "The `sparsecore` placement is not available on " + f"{jax.default_backend()}" + ) + + if importlib.util.find_spec("jax_tpu_embedding") is None: raise ImportError( - "Please install jax-tpu-embedding to use " - "DistributedEmbedding on sparsecore devices." + "DistributedEmbedding on TPU with SparseCore chips " + "requires the jax-tpu-embedding module. You can install it " + "via `pip install jax-tpu-embedding`" ) raise self._unsupported_placement_error("sparsecore") diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 197b3042..9fd17768 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -31,7 +31,9 @@ if jax.__version_info__ >= (0, 8, 0): from jax import shard_map else: - from jax.experimental.shard_map import shard_map # type: ignore[assignment] + from jax.experimental.shard_map import ( # type: ignore[assignment, no-redef] + shard_map, + ) ArrayLike = Union[np.ndarray[Any, Any], jax.Array] @@ -348,9 +350,16 @@ def _sparsecore_init( feature_configs: dict[str, FeatureConfig], table_stacking: str | Sequence[str] | Sequence[Sequence[str]], ) -> None: - if not self._has_sparsecore(): + if not self.has_sparsecores(): + if jax.default_backend() != "tpu": + # If jax-tpu-embedding is installed, but running on GPU or CPU. + raise ValueError( + "The `sparsecore` placement is not available on " + f"{jax.default_backend()}" + ) + raise ValueError( - "Not sparse cores available, cannot use explicit sparsecore" + "No SparseCores available, cannot use explicit `sparsecore`" " placement." ) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py index 1dd3525d..420121e8 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py @@ -60,17 +60,9 @@ def _create_sparsecore_layout( return sparsecore_layout -def _num_sparsecores_per_device() -> int: - if test_utils.has_sparsecores(): - return jte_utils.num_sparsecores_per_device() - - # Default to one for non-sparsecore tests. - return 1 - - @pytest.mark.skipif( - keras.backend.backend() != "jax", - reason="Backend specific test", + not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(), + reason="Requires SparseCores", ) class ShardedInitializerTest(parameterized.TestCase): @parameterized.product( @@ -103,8 +95,8 @@ def test_wrap_and_call( @pytest.mark.skipif( - keras.backend.backend() != "jax", - reason="Backend specific test", + not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(), + reason="Requires SparseCores", ) class StackedTableInitializerTest(parameterized.TestCase): def test_sharded_matches_unsharded(self): @@ -130,7 +122,7 @@ def test_sharded_matches_unsharded(self): ) device_count = jax.device_count() - num_sc_per_device = _num_sparsecores_per_device() + num_sc_per_device = jte_utils.num_sparsecores_per_device() num_table_shards = device_count * num_sc_per_device # Convert to JAX and stack tables. @@ -208,7 +200,7 @@ def test_random_shards(self): } device_count = jax.device_count() - num_sc_per_device = _num_sparsecores_per_device() + num_sc_per_device = jte_utils.num_sparsecores_per_device() num_table_shards = device_count * num_sc_per_device table_stacking_lib.stack_tables( @@ -267,7 +259,7 @@ def test_compilability(self): } device_count = jax.device_count() - num_sc_per_device = _num_sparsecores_per_device() + num_sc_per_device = jte_utils.num_sparsecores_per_device() num_table_shards = device_count * num_sc_per_device table_stacking_lib.stack_tables( @@ -305,8 +297,8 @@ def my_initializer(shape: tuple[int, int], dtype: Any): @pytest.mark.skipif( - keras.backend.backend() != "jax", - reason="Backend specific test", + not jax_distributed_embedding.DistributedEmbedding.has_sparsecores(), + reason="Requires SparseCores", ) class DistributedEmbeddingLayerTest(parameterized.TestCase): @parameterized.product( diff --git a/keras_rs/src/layers/embedding/jax/embedding_lookup.py b/keras_rs/src/layers/embedding/jax/embedding_lookup.py index e5336d44..dde78b18 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_lookup.py +++ b/keras_rs/src/layers/embedding/jax/embedding_lookup.py @@ -21,7 +21,7 @@ else: from jax.experimental.shard_map import shard_map as exp_shard_map - def shard_map( # type: ignore[misc] + def shard_map( # type: ignore[misc, no-redef] f: Any = None, /, *, diff --git a/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py b/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py index e76d79c7..4573e8a5 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py +++ b/keras_rs/src/layers/embedding/jax/embedding_lookup_test.py @@ -14,6 +14,7 @@ from jax_tpu_embedding.sparsecore.lib.nn import table_stacking from jax_tpu_embedding.sparsecore.utils import utils as jte_utils +from keras_rs.src.layers.embedding.jax import distributed_embedding from keras_rs.src.layers.embedding.jax import embedding_lookup from keras_rs.src.layers.embedding.jax import embedding_utils from keras_rs.src.layers.embedding.jax import test_utils @@ -133,7 +134,7 @@ def _create_table_and_feature_specs( stacked=[True, False], ) def test_forward_pass(self, ragged: bool, stacked: bool): - if not test_utils.has_sparsecores(): + if not distributed_embedding.DistributedEmbedding.has_sparsecores(): self.skipTest("Test requires sparsecores.") devices = jax.devices() @@ -215,7 +216,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool): def test_model_sharding( self, ragged: bool, stacked: bool, num_model_shards: int ): - if not test_utils.has_sparsecores(): + if not distributed_embedding.DistributedEmbedding.has_sparsecores(): self.skipTest("Test requires sparsecores.") if num_model_shards > jax.device_count(): @@ -319,7 +320,7 @@ def test_backward_pass( stacked: bool, optimizer: embedding_spec.OptimizerSpec, ): - if not test_utils.has_sparsecores(): + if not distributed_embedding.DistributedEmbedding.has_sparsecores(): self.skipTest("Test requires sparsecores.") devices = jax.devices() @@ -426,7 +427,7 @@ def test_autograd( stacked: bool, optimizer: embedding_spec.OptimizerSpec, ): - if not test_utils.has_sparsecores(): + if not distributed_embedding.DistributedEmbedding.has_sparsecores(): self.skipTest("Test requires sparsecores.") devices = jax.devices() diff --git a/keras_rs/src/layers/embedding/jax/test_utils.py b/keras_rs/src/layers/embedding/jax/test_utils.py index 01cf9131..d6694d51 100644 --- a/keras_rs/src/layers/embedding/jax/test_utils.py +++ b/keras_rs/src/layers/embedding/jax/test_utils.py @@ -20,13 +20,6 @@ Shape: TypeAlias = tuple[int, ...] -def has_sparsecores() -> bool: - device_kind = jax.devices()[0].device_kind - if device_kind in ["TPU v5", "TPU v6 lite"]: - return True - return False - - def _round_up_to_multiple(value: int, multiple: int) -> int: return ((value + multiple - 1) // multiple) * multiple diff --git a/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py b/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py index 79e779d1..d7bc9fe6 100644 --- a/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py @@ -61,15 +61,17 @@ def __init__( feature_configs, table_stacking=table_stacking, **kwargs ) - def _is_tpu_strategy(self, strategy: tf.distribute.Strategy) -> bool: + @classmethod + def _is_tpu_strategy(cls, strategy: tf.distribute.Strategy) -> bool: return isinstance( strategy, (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy), ) - def _has_sparsecore(self) -> bool: + @classmethod + def has_sparsecores(cls) -> bool: strategy = tf.distribute.get_strategy() - if self._is_tpu_strategy(strategy): + if cls._is_tpu_strategy(strategy): tpu_embedding_feature = ( strategy.extended.tpu_hardware_feature.embedding_feature )