Skip to content

Commit b0226f8

Browse files
Hilly12recml authors
authored andcommitted
Add DLRM-V2 with sparsecore.
PiperOrigin-RevId: 745051634
1 parent c43d6ae commit b0226f8

File tree

10 files changed

+1135
-38
lines changed

10 files changed

+1135
-38
lines changed

recml/core/data/iterator.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,15 @@ def __next__(self) -> clu_data.Element:
5757
if self._prefetched_batch is not None:
5858
batch = self._prefetched_batch
5959
self._prefetched_batch = None
60-
return batch
61-
62-
batch = next(self._iterator)
63-
if self._postprocessor is not None:
64-
batch = self._postprocessor(batch)
60+
else:
61+
batch = next(self._iterator)
62+
if self._postprocessor is not None:
63+
batch = self._postprocessor(batch)
6564

6665
def _maybe_to_numpy(
67-
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
66+
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
6867
) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
69-
if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)):
68+
if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
7069
return x
7170
if hasattr(x, "_numpy"):
7271
numpy = x._numpy() # pylint: disable=protected-access
@@ -83,13 +82,16 @@ def _maybe_to_numpy(
8382
@property
8483
def element_spec(self) -> clu_data.ElementSpec:
8584
if self._element_spec is not None:
86-
batch = self._element_spec
87-
else:
88-
batch = self.__next__()
89-
self._prefetched_batch = batch
85+
return self._element_spec
86+
87+
batch = next(self._iterator)
88+
if self._postprocessor is not None:
89+
batch = self._postprocessor(batch)
90+
91+
self._prefetched_batch = batch
9092

9193
def _to_element_spec(
92-
x: np.ndarray | tf.SparseTensor | tf.RaggedTensor,
94+
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
9395
) -> clu_data.ArraySpec:
9496
if isinstance(x, tf.SparseTensor):
9597
return clu_data.ArraySpec(
@@ -101,6 +103,10 @@ def _to_element_spec(
101103
dtype=x.dtype.as_numpy_dtype, # pylint: disable=attribute-error
102104
shape=tuple(x.shape.as_list()), # pylint: disable=attribute-error
103105
)
106+
if isinstance(x, tf.Tensor):
107+
return clu_data.ArraySpec(
108+
dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
109+
)
104110
return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))
105111

106112
element_spec = tf.nest.map_structure(_to_element_spec, batch)

recml/core/ops/embedding_ops.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2024 RecML authors <[email protected]>.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Embedding lookup ops."""
15+
16+
from collections.abc import Mapping, Sequence
17+
import dataclasses
18+
import functools
19+
20+
import jax
21+
from jax.experimental import shard_map
22+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
23+
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
24+
25+
26+
@dataclasses.dataclass
27+
class SparsecoreParams:
28+
"""Embedding parameters."""
29+
30+
feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
31+
abstract_mesh: jax.sharding.AbstractMesh
32+
data_axes: Sequence[str | None]
33+
embedding_axes: Sequence[str | None]
34+
sharding_strategy: str
35+
36+
37+
@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
38+
def sparsecore_lookup(
39+
sparsecore_params: SparsecoreParams,
40+
tables: Mapping[str, tuple[jax.Array, ...]],
41+
csr_inputs: tuple[jax.Array, ...],
42+
):
43+
return shard_map.shard_map(
44+
functools.partial(
45+
embedding.tpu_sparse_dense_matmul,
46+
global_device_count=sparsecore_params.abstract_mesh.size,
47+
feature_specs=sparsecore_params.feature_specs,
48+
sharding_strategy=sparsecore_params.sharding_strategy,
49+
),
50+
mesh=sparsecore_params.abstract_mesh,
51+
in_specs=(
52+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
53+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
54+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
55+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
56+
jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
57+
),
58+
out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
59+
check_rep=False,
60+
)(*csr_inputs, tables)
61+
62+
63+
def _emb_lookup_fwd(
64+
sparsecore_params: SparsecoreParams,
65+
tables: Mapping[str, tuple[jax.Array, ...]],
66+
csr_inputs: tuple[jax.Array, ...],
67+
):
68+
out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
69+
return out, (tables, csr_inputs)
70+
71+
72+
def _emb_lookup_bwd(
73+
sparsecore_params: SparsecoreParams,
74+
res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
75+
gradients: embedding.Nested[jax.Array],
76+
) -> tuple[embedding.Nested[jax.Array], None]:
77+
"""Backward pass for embedding lookup."""
78+
(tables, csr_inputs) = res
79+
80+
emb_table_grads = shard_map.shard_map(
81+
functools.partial(
82+
embedding.tpu_sparse_dense_matmul_grad,
83+
feature_specs=sparsecore_params.feature_specs,
84+
sharding_strategy=sparsecore_params.sharding_strategy,
85+
),
86+
mesh=sparsecore_params.abstract_mesh,
87+
in_specs=(
88+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
89+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
90+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
91+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
92+
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
93+
jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
94+
),
95+
out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
96+
check_rep=False,
97+
)(gradients, *csr_inputs, tables)
98+
99+
# `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
100+
# It may not be the same type as the embedding table (e.g. FrozenDict).
101+
# Here we use flatten / unflatten to ensure the types are the same.
102+
emb_table_grads = jax.tree.unflatten(
103+
jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
104+
)
105+
106+
return emb_table_grads, None
107+
108+
109+
sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)

recml/core/training/jax.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from clu import periodic_actions
2727
import clu.metrics as clu_metrics
2828
from flax import struct
29+
import flax.linen as nn
2930
import jax
3031
import jax.numpy as jnp
3132
import keras
@@ -67,43 +68,85 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
6768
step: A counter of the current step of the job. It starts at zero and it is
6869
incremented by 1 on a call to `state.update(...)`. This should be a Jax
6970
array and not a Python integer.
70-
apply: A function that can be used to apply the forward pass of the model.
71-
For Flax models this is usually set to `model.apply`.
7271
params: A pytree of trainable variables that will be updated by `tx` and
7372
used in `apply`.
7473
tx: An optax gradient transformation that will be used to update the
7574
parameters contained in `params` on a call to `state.update(...)`.
7675
opt_state: The optimizer state for `tx`. This is usually created by calling
7776
`tx.init(params)`.
77+
_apply: An optional function that can be used to apply the forward pass of
78+
the model. For Flax models this is usually set to `model.apply` while for
79+
Haiku models this is usually set to `transform.apply`.
80+
_model: An optional reference to a stateless Flax model for convenience.
7881
mutable: A pytree of mutable variables that are used by `apply`.
7982
meta: Arbitrary metadata that is recorded on the state. This can be useful
8083
for tracking additional references in the state.
8184
"""
8285

8386
step: jax.Array
84-
apply: Callable[..., Any] = struct.field(pytree_node=False)
8587
params: PyTree = struct.field(pytree_node=True)
8688
tx: optax.GradientTransformation = struct.field(pytree_node=False)
8789
opt_state: optax.OptState = struct.field(pytree_node=True)
8890
mutable: PyTree = struct.field(pytree_node=True, default_factory=dict)
8991
meta: MetaT = struct.field(pytree_node=False, default_factory=dict)
92+
_apply: Callable[..., Any] | None = struct.field(
93+
pytree_node=False, default_factory=None
94+
)
95+
_model: nn.Module | None = struct.field(pytree_node=False, default=None)
96+
97+
@property
98+
def model(self) -> nn.Module:
99+
"""Returns a reference to the model used to create the state."""
100+
if self._model is None:
101+
raise ValueError("No Flax `model` is set on the state.")
102+
return self._model
103+
104+
def apply(self, *args, **kwargs) -> Any:
105+
"""Applies the forward pass of the model."""
106+
if self._apply is None:
107+
raise ValueError("No `apply` function is set on the state.")
108+
return self._apply(*args, **kwargs)
90109

91110
@classmethod
92111
def create(
93112
cls,
94113
*,
95-
apply: Callable[..., Any],
114+
apply: Callable[..., Any] | None = None,
115+
model: nn.Module | None = None,
96116
params: PyTree,
97117
tx: optax.GradientTransformation,
98118
**kwargs,
99119
) -> Self:
100-
"""Creates a new instance from a Jax apply function and Optax optimizer."""
120+
"""Creates a new instance from a Jax model / apply fn and Optax optimizer.
121+
122+
Args:
123+
apply: A function that can be used to apply the forward pass of the model.
124+
For Flax models this is usually set to `model.apply`. This cannot be set
125+
along with `model`.
126+
model: A reference to a stateless Flax model. This cannot be set along
127+
with `apply`. When set the `apply` attribute of the state will be set to
128+
`model.apply`.
129+
params: A pytree of trainable variables that will be updated by `tx` and
130+
used in `apply`.
131+
tx: An optax gradient transformation that will be used to update the
132+
parameters contained in `params` on a call to `state.update(...)`.
133+
**kwargs: Other updates to set on the new state.
134+
135+
Returns:
136+
An new instance of the state.
137+
"""
138+
if apply is not None and model is not None:
139+
raise ValueError("Only one of `apply` or `model` can be provided.")
140+
elif model is not None:
141+
apply = model.apply
142+
101143
return cls(
102144
step=jnp.zeros([], dtype=jnp.int32),
103-
apply=apply,
104145
params=params,
105146
tx=tx,
106147
opt_state=tx.init(params),
148+
_apply=apply,
149+
_model=model,
107150
**kwargs,
108151
)
109152

recml/core/training/optax_factory.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ def _default_weight_decay_mask(params: optax.Params) -> optax.Params:
2929

3030

3131
def _regex_mask(regex: str) -> Callable[[optax.Params], optax.Params]:
32-
"""Returns a weight decay mask that applies to parameters matching a regex."""
32+
"""Returns a mask that applies to parameters matching a regex."""
3333

3434
def _matches_regex(path: tuple[str, ...], _: Any) -> bool:
35-
key = "/".join([jax.tree_util.keystr((k,), simple=True) for k in path])
35+
key = '/'.join([jax.tree_util.keystr((k,), simple=True) for k in path])
3636
return re.fullmatch(regex, key) is not None
3737

3838
def _mask(params: optax.Params) -> optax.Params:
@@ -54,6 +54,8 @@ class OptimizerFactory(types.Factory[optax.GradientTransformation]):
5454
magnitude of the gradients during optimization. Defaults to None.
5555
weight_decay_mask: The weight decay mask to use when applying weight decay.
5656
Defaults applying weight decay to all non-1D parameters.
57+
freeze_mask: Optional mask to freeze parameters during optimization.
58+
Defaults to None.
5759
5860
Example usage:
5961
@@ -78,6 +80,7 @@ class OptimizerFactory(types.Factory[optax.GradientTransformation]):
7880
weight_decay_mask: str | Callable[[optax.Params], optax.Params] = (
7981
_default_weight_decay_mask
8082
)
83+
freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
8184

8285
def make(self) -> optax.GradientTransformation:
8386
if self.grad_clip_norm is not None:
@@ -99,13 +102,30 @@ def make(self) -> optax.GradientTransformation:
99102
else:
100103
weight_decay = optax.identity()
101104

102-
return optax.chain(*[
105+
tx = optax.chain(*[
103106
apply_clipping,
104107
self.scaling,
105108
weight_decay,
106109
lr_scaling,
107110
])
108111

112+
if self.freeze_mask is not None:
113+
if isinstance(self.freeze_mask, str):
114+
mask = _regex_mask(self.freeze_mask)
115+
else:
116+
mask = self.freeze_mask
117+
118+
def _param_labels(params: optax.Params) -> optax.Params:
119+
return jax.tree.map(
120+
lambda p: 'frozen' if mask(p) else 'trainable', params
121+
)
122+
123+
tx = optax.multi_transform(
124+
transforms={'trainable': tx, 'frozen': optax.set_to_zero()},
125+
param_labels=_param_labels,
126+
)
127+
return tx
128+
109129

110130
class AdamFactory(types.Factory[optax.GradientTransformation]):
111131
"""Adam optimizer factory.
@@ -121,6 +141,8 @@ class AdamFactory(types.Factory[optax.GradientTransformation]):
121141
magnitude of the gradients during optimization. Defaults to None.
122142
weight_decay_mask: The weight decay mask to use when applying weight decay.
123143
Defaults applying weight decay to all non-1D parameters.
144+
freeze_mask: Optional mask to freeze parameters during optimization.
145+
Defaults to None.
124146
125147
Example usage:
126148
```
@@ -143,6 +165,7 @@ class AdamFactory(types.Factory[optax.GradientTransformation]):
143165
weight_decay_mask: str | Callable[[optax.Params], optax.Params] = (
144166
_default_weight_decay_mask
145167
)
168+
freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
146169

147170
def make(self) -> optax.GradientTransformation:
148171
return OptimizerFactory(
@@ -164,6 +187,8 @@ class AdagradFactory(types.Factory[optax.GradientTransformation]):
164187
eps: The epsilon coefficient for the Adagrad optimizer. Defaults to 1e-7.
165188
grad_clip_norm: Optional gradient clipping norm to limit the maximum
166189
magnitude of the gradients during optimization. Defaults to None.
190+
freeze_mask: Optional mask to freeze parameters during optimization.
191+
Defaults to None.
167192
168193
Example usage:
169194
```
@@ -175,6 +200,7 @@ class AdagradFactory(types.Factory[optax.GradientTransformation]):
175200
initial_accumulator_value: float = 0.1
176201
eps: float = 1e-7
177202
grad_clip_norm: float | None = None
203+
freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None
178204

179205
def make(self) -> optax.GradientTransformation:
180206
return OptimizerFactory(

0 commit comments

Comments
 (0)