forked from keras-team/keras-rs
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdistributed_embedding.py
More file actions
445 lines (397 loc) · 17.7 KB
/
distributed_embedding.py
File metadata and controls
445 lines (397 loc) · 17.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
from typing import Any, Callable, Sequence, TypeAlias
import keras
import tensorflow as tf
from keras_rs.src import types
from keras_rs.src.layers.embedding import base_distributed_embedding
from keras_rs.src.layers.embedding import distributed_embedding_config
from keras_rs.src.layers.embedding.tensorflow import config_conversion
from keras_rs.src.utils import keras_utils
# Local aliases for the Keras RS configuration classes.
FeatureConfig = distributed_embedding_config.FeatureConfig
TableConfig = distributed_embedding_config.TableConfig

# Placeholder of tf.tpu.experimental.embedding._Optimizer which is not exposed.
TfTpuOptimizer: TypeAlias = Any

# Name given to the dummy trainable weight used by the gradient trap. The layer
# stores that weight under a Python attribute of the same name, and the two
# must stay in sync.
GRADIENT_TRAP_DUMMY_NAME = "_gradient_trap_dummy"

# Values of `strategy.extended.tpu_hardware_feature.embedding_feature` that the
# layer distinguishes between.
_EmbeddingFeature = tf.tpu.experimental.HardwareFeature.EmbeddingFeature
EMBEDDING_FEATURE_V1 = _EmbeddingFeature.V1
EMBEDDING_FEATURE_V2 = _EmbeddingFeature.V2
UNSUPPORTED = _EmbeddingFeature.UNSUPPORTED
# The alias is only a spelling convenience; keep the module namespace unchanged.
del _EmbeddingFeature
class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
    """TensorFlow implementation of the TPU embedding layer."""

    def __init__(
        self,
        feature_configs: types.Nested[
            FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig
        ],
        *,
        table_stacking: (
            str | Sequence[str] | Sequence[Sequence[str]]
        ) = "auto",
        update_stats: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initializes the layer.

        Args:
            feature_configs: Nested structure of feature configurations, given
                either as Keras RS `FeatureConfig`s or as
                `tf.tpu.experimental.embedding.FeatureConfig`s.
            table_stacking: Table stacking specification, forwarded to the
                base class.
            update_stats: Must be False; it is only supported on JAX.
            **kwargs: The TensorFlow-only options `optimizer`,
                `pipeline_execution_with_tensor_core` and
                `sparse_core_embedding_config` are consumed here. All other
                keyword arguments are forwarded to the base class.

        Raises:
            ValueError: If `update_stats` is True.
        """
        # `update_stats` is supported only on JAX.
        if update_stats:
            raise ValueError(
                "`update_stats` cannot be True for the TensorFlow backend."
            )

        # Intercept arguments that are supported only on TensorFlow. They are
        # popped so that they are not forwarded to the base class.
        self._optimizer = kwargs.pop("optimizer", None)
        self._pipeline_execution_with_tensor_core = kwargs.pop(
            "pipeline_execution_with_tensor_core", False
        )
        self._sparse_core_embedding_config = kwargs.pop(
            "sparse_core_embedding_config", None
        )

        # Mark as True by default for `_verify_input_shapes`. This will be
        # updated in `_sparsecore_init` if applicable.
        self._using_keras_rs_configuration = True

        super().__init__(
            feature_configs, table_stacking=table_stacking, **kwargs
        )

    @classmethod
    def _is_tpu_strategy(cls, strategy: tf.distribute.Strategy) -> bool:
        """Returns whether `strategy` is a (possibly experimental) TPUStrategy."""
        return isinstance(
            strategy,
            (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
        )

    @classmethod
    def has_sparsecores(cls) -> bool:
        """Returns whether the current strategy exposes TPU embedding support.

        Returns:
            True when the current distribution strategy is a TPU strategy whose
            hardware reports embedding feature V1 or V2, False otherwise.
        """
        strategy = tf.distribute.get_strategy()
        if cls._is_tpu_strategy(strategy):
            tpu_embedding_feature = (
                strategy.extended.tpu_hardware_feature.embedding_feature
            )
            return tpu_embedding_feature in (
                EMBEDDING_FEATURE_V2,
                EMBEDDING_FEATURE_V1,
            )
        return False

    @keras_utils.no_automatic_dependency_tracking
    def _sparsecore_init(
        self,
        feature_configs: dict[
            str,
            FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig,
        ],
        table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
    ) -> None:
        """Creates the TPU embedding object for features placed on SparseCore.

        Keras RS configurations are converted to TF TPU configurations, and
        TF TPU configurations are cloned. Then a `TPUEmbedding` (hardware
        feature V1) or a `TPUEmbeddingV2` (hardware feature V2) is created,
        together with the dummy trainable weight used by the gradient trap.

        Args:
            feature_configs: Mapping of feature path to feature config. All
                configs must be of the same kind (Keras RS or TF TPU).
            table_stacking: Table stacking specification. Only "auto" is
                accepted with TF TPU configurations.

        Raises:
            ValueError: If not running under a TPU strategy, if Keras RS and
                TF TPU configurations are mixed, if an argument is incompatible
                with the kind of configuration or the TPU generation, or if the
                TPU does not support embeddings.
        """
        self._table_stacking = table_stacking

        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            raise ValueError(
                "Placement to sparsecore was requested, however, we are not "
                "running under a TPU strategy."
            )

        tpu_embedding_feature = (
            strategy.extended.tpu_hardware_feature.embedding_feature
        )

        # The kind of the first config decides which code path is taken below,
        # so reject a mix of the two kinds up front instead of failing later
        # inside the conversion helpers.
        self._using_keras_rs_configuration = isinstance(
            next(iter(feature_configs.values())), FeatureConfig
        )
        if any(
            isinstance(config, FeatureConfig)
            != self._using_keras_rs_configuration
            for config in feature_configs.values()
        ):
            raise ValueError(
                "All feature configs must be of the same type: either all "
                "`keras_rs.layers.FeatureConfig` or all "
                "`tf.tpu.experimental.embedding.FeatureConfig`."
            )

        if self._using_keras_rs_configuration:
            if self._sparse_core_embedding_config is not None:
                raise ValueError(
                    "The `sparse_core_embedding_config` argument is only "
                    "supported when using "
                    "`tf.tpu.experimental.embedding.FeatureConfig` instances "
                    "for the configuration."
                )
            self._tpu_feature_configs, self._sparse_core_embedding_config = (
                config_conversion.keras_to_tf_tpu_configuration(
                    feature_configs,
                    table_stacking,
                    strategy.num_replicas_in_sync,
                )
            )
            if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
                # Remove auto-generated SparseCoreEmbeddingConfig, which is not
                # used.
                self._sparse_core_embedding_config = None
        else:
            if table_stacking != "auto":
                raise ValueError(
                    "The `table_stacking` argument is not supported when using "
                    "`tf.tpu.experimental.embedding.FeatureConfig` for the "
                    "configuration. You can use the `disable_table_stacking` "
                    "attribute of "
                    "`tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig` "
                    "to disable table stacking."
                )
            if (
                tpu_embedding_feature == EMBEDDING_FEATURE_V1
                and self._sparse_core_embedding_config is not None
            ):
                raise ValueError(
                    "The `sparse_core_embedding_config` argument is not "
                    "supported with this TPU generation."
                )
            self._tpu_feature_configs = (
                config_conversion.clone_tf_tpu_feature_configs(feature_configs)
            )

        self._tpu_optimizer = config_conversion.to_tf_tpu_optimizer(
            self._optimizer
        )

        if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
            self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbedding(
                self._tpu_feature_configs,
                self._tpu_optimizer,
                self._pipeline_execution_with_tensor_core,
            )
            # Incremented on every lookup to give each layer call a unique op
            # name; see `_tpu_embedding_lookup_v1`.
            self._v1_call_id = 0
        elif tpu_embedding_feature == EMBEDDING_FEATURE_V2:
            self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbeddingV2(
                self._tpu_feature_configs,
                self._tpu_optimizer,
                self._pipeline_execution_with_tensor_core,
                self._sparse_core_embedding_config,
            )
        elif tpu_embedding_feature == UNSUPPORTED:
            raise ValueError(
                "Placement to sparsecore was requested, however, this TPU does "
                "not support it."
            )
        else:
            # Any value other than V1, V2 or UNSUPPORTED is unknown to us.
            raise ValueError(
                f"Unsupported TPU embedding feature: {tpu_embedding_feature}."
            )

        # We need at least one trainable variable for the gradient trap to work.
        # Note that the Python attribute name "_gradient_trap_dummy" should
        # match the name of the variable GRADIENT_TRAP_DUMMY_NAME.
        self._gradient_trap_dummy = self.add_weight(
            name=GRADIENT_TRAP_DUMMY_NAME,
            shape=(1,),
            initializer=tf.zeros_initializer(),
            trainable=True,
            dtype=tf.float32,
        )

    def compute_output_shape(
        self, input_shapes: types.Nested[types.Shape]
    ) -> types.Nested[types.Shape]:
        """Returns the output shape of every feature.

        With Keras RS configurations this defers to the base class. With TF TPU
        configurations, each feature's shape is derived from its config. The
        config's `output_shape` takes precedence when its rank is known.
        Otherwise `max_sequence_length` is used for rank-2 inputs whose last
        dimension is not 1. In all other cases `table.dim` is appended after
        the batch dimensions.

        Raises:
            ValueError: If an input shape has rank 0.
        """
        if self._using_keras_rs_configuration:
            return super().compute_output_shape(input_shapes)

        def _compute_output_shape(
            feature_config: tf.tpu.experimental.embedding.FeatureConfig,
            input_shape: types.Shape,
        ) -> types.Shape:
            if len(input_shape) < 1:
                raise ValueError(
                    f"Received input shape {input_shape}. Rank must be 1 or "
                    "above."
                )
            max_sequence_length: int = feature_config.max_sequence_length
            embed_dim = feature_config.table.dim
            if (
                feature_config.output_shape is not None
                and feature_config.output_shape.rank is not None
            ):
                return tuple(feature_config.output_shape.as_list())
            elif (
                len(input_shape) == 2
                and input_shape[-1] != 1
                and max_sequence_length > 0
            ):
                # Update the input shape with the max sequence length. Only
                # update when:
                # 1. Input feature is 2D ragged or sparse tensor.
                # 2. Output shape is not set and max sequence length is set.
                return tuple(input_shape[:-1]) + (
                    max_sequence_length,
                    embed_dim,
                )
            elif len(input_shape) == 1:
                return tuple(input_shape) + (embed_dim,)
            else:
                return tuple(input_shape[:-1]) + (embed_dim,)

        output_shapes: types.Nested[types.Shape] = (
            keras.tree.map_structure_up_to(
                self._feature_configs,
                _compute_output_shape,
                self._feature_configs,
                input_shapes,
            )
        )
        return output_shapes

    def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
        """Builds the underlying TPU embedding object.

        The V1 `TPUEmbedding` is built with per-replica input shapes, and its
        `build` is converted with autograph (non-recursively). The
        `TPUEmbedding` instance is passed explicitly as the first argument.
        The V2 object is built without arguments.
        """
        if isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
        ):
            tf_input_shapes = keras.tree.map_shape_structure(
                tf.TensorShape, input_shapes
            )
            tpu_embedding_build = tf.autograph.to_graph(
                self._tpu_embedding.build, recursive=False
            )
            tpu_embedding_build(
                self._tpu_embedding, per_replica_input_shapes=tf_input_shapes
            )
        elif isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
        ):
            self._tpu_embedding.build()

    def _sparsecore_call(
        self,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
        training: bool = False,
    ) -> dict[str, types.Tensor]:
        """Looks up the SparseCore features and returns their activations.

        `training` is ignored. The V1 path always enqueues with
        `training=True`; see `_tpu_embedding_lookup_v1`.

        Raises:
            RuntimeError: If not called under a TPU strategy.
            ValueError: If no TPU embedding object of a known type exists.
        """
        del training  # Unused.

        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            raise RuntimeError(
                "DistributedEmbedding needs to be called under a TPUStrategy "
                "for features placed on the embedding feature but is being "
                f"called under strategy {strategy}. Please use `strategy.run` "
                "when calling this layer."
            )

        if isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
        ):
            return self._tpu_embedding_lookup_v1(
                self._tpu_embedding, inputs, weights
            )
        elif isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
        ):
            return self._tpu_embedding_lookup_v2(
                self._tpu_embedding, inputs, weights
            )
        else:
            raise ValueError(
                "DistributedEmbedding is receiving features to lookup on the "
                "TPU embedding feature but no such feature was configured."
            )

    def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
        """Returns the full, unsharded embedding tables keyed by table name.

        Raises:
            RuntimeError: If not called under a TPU strategy.
        """
        tables: dict[str, types.Tensor] = {}

        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            raise RuntimeError(
                "`DistributedEmbedding.get_embedding_tables` needs to be "
                "called under the TPUStrategy that DistributedEmbedding was "
                f"created with, but is being called under strategy {strategy}. "
                "Please use `with strategy.scope()` when calling "
                "`get_embedding_tables`."
            )

        tpu_hardware = strategy.extended.tpu_hardware_feature
        num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
        num_shards = strategy.num_replicas_in_sync * num_sc_per_device

        def populate_table(
            feature_config: tf.tpu.experimental.embedding.FeatureConfig,
        ) -> None:
            table_name = feature_config.table.name
            if table_name in tables:
                # Several features can share a table; reconstruct it once.
                return

            embedding_dim = feature_config.table.dim
            table = self._tpu_embedding.embedding_tables[table_name]

            # Rows are mod-sharded across `num_shards` shards. The concat,
            # split, concat and reshape below place row i of shard s at
            # position i * num_shards + s. Columns beyond `embedding_dim` are
            # dropped, and rows beyond `vocabulary_size` (padding) are
            # truncated.
            table_shards = [
                shard.numpy()[:, :embedding_dim] for shard in table.values
            ]
            full_table = keras.ops.concatenate(table_shards, axis=0)
            full_table = keras.ops.concatenate(
                keras.ops.split(full_table, num_shards, axis=0), axis=1
            )
            full_table = keras.ops.reshape(full_table, [-1, embedding_dim])
            tables[table_name] = full_table[
                : feature_config.table.vocabulary_size, :
            ]

        keras.tree.map_structure(populate_table, self._tpu_feature_configs)

        return tables

    def _verify_input_shapes(
        self, input_shapes: types.Nested[types.Shape]
    ) -> None:
        """Verifies input shapes; only possible with Keras RS configurations."""
        if self._using_keras_rs_configuration:
            return super()._verify_input_shapes(input_shapes)
        # `tf.tpu.experimental.embedding.FeatureConfig` does not provide any
        # information about the input shape, so there is nothing to verify.

    def _tpu_embedding_lookup_v1(
        self,
        tpu_embedding: tf.tpu.experimental.embedding.TPUEmbedding,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
    ) -> dict[str, types.Tensor]:
        """Enqueues, dequeues and wires gradients for the V1 TPU embedding.

        The activations are produced inside a `tf.custom_gradient` function
        applied to the dummy trainable weight. Its gradient function therefore
        receives the activation gradients and forwards them to
        `tpu_embedding.apply_gradients`.
        """
        # Each call to this function increments the _v1_call_id by 1, this
        # allows us to tag each of the main embedding ops with this call id so
        # that we know during graph rewriting passes which ops correspond to the
        # same layer call.
        self._v1_call_id += 1
        name = str(self._v1_call_id)

        # Set training to true, even during eval. When name is set, this will
        # trigger a pass that updates the training based on if there is a send
        # gradients with the same name.
        tpu_embedding.enqueue(inputs, weights, training=True, name=name)

        @tf.custom_gradient  # type: ignore
        def gradient_trap(
            dummy: types.Tensor,
        ) -> tuple[
            list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
        ]:
            """Register a gradient function for activation."""
            activations = tpu_embedding.dequeue(name=name)

            def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
                """Gradient function."""
                # Since the output were flattened, the gradients are also
                # flattened. Pack them back into the correct nested structure.
                gradients = tf.nest.pack_sequence_as(
                    self._placement_to_path_to_feature_config["sparsecore"],
                    grad_wrt_activations,
                )
                tpu_embedding.apply_gradients(gradients, name=name)

                # This is the gradient for the input variable.
                return tf.zeros_like(dummy)

            # Custom gradient functions don't like nested structures of tensors,
            # so we flatten them here.
            return tf.nest.flatten(activations), grad

        activations_with_trap = gradient_trap(self._gradient_trap_dummy.value)
        result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
            self._placement_to_path_to_feature_config["sparsecore"],
            activations_with_trap,
        )
        return result

    def _tpu_embedding_lookup_v2(
        self,
        tpu_embedding: tf.tpu.experimental.embedding.TPUEmbeddingV2,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
    ) -> dict[str, types.Tensor]:
        """Looks up and wires gradients for the V2 TPU embedding.

        As in the V1 path, the dummy weight is used as a gradient trap. Here
        the `preserved_result` returned by the forward call is passed back to
        `apply_gradients`.
        """

        @tf.custom_gradient  # type: ignore
        def gradient_trap(
            dummy: types.Tensor,
        ) -> tuple[
            list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
        ]:
            """Register a gradient function for activation."""
            activations, preserved_result = tpu_embedding(inputs, weights)

            def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
                """Gradient function."""
                # Since the output were flattened, the gradients are also
                # flattened. Pack them back into the correct nested structure.
                gradients = tf.nest.pack_sequence_as(
                    self._placement_to_path_to_feature_config["sparsecore"],
                    grad_wrt_activations,
                )
                tpu_embedding.apply_gradients(
                    gradients, preserved_outputs=preserved_result
                )

                # This is the gradient for the input variable.
                return tf.zeros_like(dummy)

            # Custom gradient functions don't like nested structures of tensors,
            # so we flatten them here.
            return tf.nest.flatten(activations), grad

        activations_with_trap = gradient_trap(self._gradient_trap_dummy)
        result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
            self._placement_to_path_to_feature_config["sparsecore"],
            activations_with_trap,
        )
        return result

    def _trackable_children(
        self, save_type: str = "checkpoint", **kwargs: Any
    ) -> dict[str, Any]:
        """Returns the trackable children, excluding the gradient trap dummy."""
        # Remove dummy variable, we don't want it in checkpoints.
        children: dict[str, Any] = super()._trackable_children(
            save_type, **kwargs
        )
        children.pop(GRADIENT_TRAP_DUMMY_NAME, None)
        return children
# Expose the base class's docstring as the public docstring of this TensorFlow
# subclass, replacing the one-line docstring on the class body.
setattr(
    DistributedEmbedding,
    "__doc__",
    base_distributed_embedding.DistributedEmbedding.__doc__,
)