Skip to content

Commit 09c8411

Browse files
adityagupta1089recml authors
authored andcommitted
Change heuristic based parameter static_buffer_size_multiplier to FDO tunable parameter feature_spec.table_spec.[stacked_table_spec.]suggested_coo_buffer_size.
PiperOrigin-RevId: 761725839
1 parent d5abd99 commit 09c8411

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

recml/layers/linen/sparsecore.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from recml.core.ops import embedding_ops
3131
import tensorflow as tf
3232

33+
3334
with epy.lazy_imports():
3435
# pylint: disable=g-import-not-at-top
3536
from jax_tpu_embedding.sparsecore.lib.nn import embedding
@@ -99,8 +100,6 @@ class SparsecoreEmbedder:
99100
fixed mapping is used to determine this based on device 0. This may fail
100101
on newer TPU architectures if the mapping is not updated of if device 0 is
101102
not a TPU device with a sparsecore.
102-
static_buffer_size_multiplier: The multiplier to use for the static buffer
103-
size. Defaults to 256.
104103
105104
Example usage:
106105
```python
@@ -142,7 +141,6 @@ def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array:
142141
specs: Mapping[str, EmbeddingSpec]
143142
optimizer: OptimizerSpec
144143
sharding_strategy: str = 'MOD'
145-
static_buffer_size_multiplier: int = 256
146144

147145
def __post_init__(self):
148146
self._feature_specs = None
@@ -254,7 +252,6 @@ def _preprocessor(inputs):
254252
global_device_count=jax.device_count(),
255253
num_sc_per_device=self._num_sc_per_device,
256254
sharding_strategy=self.sharding_strategy,
257-
static_buffer_size_multiplier=self.static_buffer_size_multiplier,
258255
allow_id_dropping=False,
259256
)
260257

0 commit comments

Comments
 (0)