
Commit 39802c7

Add soft capping to reversible embedding layer (#1718)
Forgetting the final output soft-cap is an easy mistake to make, and worse, generations without the soft-cap will still look plausible, just with worse actual results. Adding the cap to our reversible embedding layer is much more robust: as long as you use the layer to compute logits over the vocab, you can no longer forget the soft-cap.
1 parent 8499c92 commit 39802c7
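
For context, a minimal usage sketch of what this change enables (the sizes and cap value below are made up for illustration, and this assumes the public `keras_nlp.layers.ReversibleEmbedding` export): once `logit_soft_cap` is set on the layer, the cap is applied automatically on every `reverse=True` call, so the logit computation cannot silently skip it.

import numpy as np
from keras_nlp.layers import ReversibleEmbedding

# Hypothetical sizes, for illustration only.
vocab_size, hidden_dim = 100, 16

embedding = ReversibleEmbedding(
    input_dim=vocab_size,
    output_dim=hidden_dim,
    logit_soft_cap=30.0,  # the cap now lives in the layer itself
)

token_ids = np.random.randint(0, vocab_size, size=(2, 8))
hidden = embedding(token_ids)             # (2, 8, 16) embeddings
logits = embedding(hidden, reverse=True)  # (2, 8, 100) logits, squashed into (-30, 30)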

4 files changed (+21, -14 lines)

keras_nlp/src/layers/modeling/reversible_embedding.py

Lines changed: 13 additions & 1 deletion
@@ -50,6 +50,10 @@ class ReversibleEmbedding(keras.layers.Embedding):
             "padding" value that should be masked out.
         reverse_dtype: The dtype for the reverse projection computation.
             Defaults to the `compute_dtype` of the layer.
+        logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the
+            output logits will be scaled by
+            `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the
+            range of output logits and can improve training.
         **kwargs: other keyword arguments passed to `keras.layers.Embedding`,
             including `name`, `trainable`, `dtype` etc.

@@ -91,6 +95,7 @@ def __init__(
         embeddings_constraint=None,
         mask_zero=False,
         reverse_dtype=None,
+        logit_soft_cap=None,
         **kwargs,
     ):
         super().__init__(

@@ -104,6 +109,7 @@ def __init__(
         )
         self.tie_weights = tie_weights
         self.reverse_dtype = reverse_dtype
+        self.logit_soft_cap = logit_soft_cap

     def build(self, inputs_shape=None):
         super().build(inputs_shape)

@@ -125,7 +131,12 @@ def call(self, inputs, reverse=False):
             if self.reverse_dtype is not None:
                 inputs = ops.cast(inputs, self.reverse_dtype)
                 kernel = ops.cast(kernel, self.reverse_dtype)
-            return ops.matmul(inputs, kernel)
+            logits = ops.matmul(inputs, kernel)
+            # Optionally soft-cap logits.
+            if self.logit_soft_cap is not None:
+                soft_cap = self.logit_soft_cap
+                logits = ops.tanh(logits / soft_cap) * soft_cap
+            return logits

         return super().call(inputs)

@@ -135,6 +146,7 @@ def get_config(self):
             {
                 "tie_weights": self.tie_weights,
                 "reverse_dtype": self.reverse_dtype,
+                "logit_soft_cap": self.logit_soft_cap,
             }
         )
         return config
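
To make the new `call` path above concrete, here is a standalone sketch of just the capping step using `keras.ops` directly (the input values are arbitrary, and this assumes Keras 3, where `keras.ops.tanh` and `keras.ops.convert_to_numpy` are available). Small logits pass through nearly unchanged, while large ones saturate toward plus or minus `logit_soft_cap`.

import numpy as np
from keras import ops

soft_cap = 5.0
raw_logits = np.array([-100.0, -4.0, 0.0, 4.0, 100.0])

# The same transformation ReversibleEmbedding now applies when reverse=True:
capped = ops.tanh(raw_logits / soft_cap) * soft_cap

print(np.round(ops.convert_to_numpy(capped), 4))
# approximately: [-5.     -3.3202  0.      3.3202  5.    ]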

keras_nlp/src/layers/modeling/reversible_embedding_test.py

Lines changed: 7 additions & 0 deletions
@@ -39,6 +39,7 @@ def test_layer_behaviors_tied(self, tie_weights):
                 "output_dim": 32,
                 "tie_weights": tie_weights,
                 "embeddings_initializer": "HeNormal",
+                "logit_soft_cap": 50,
             },
             input_data=random.randint(minval=0, maxval=100, shape=(4, 10)),
             expected_output_shape=(4, 10, 32),

@@ -80,6 +81,12 @@ def test_correctness(self):
         out = layer(np.array(([[1.0, 1.0]])), reverse=True)
         self.assertAllClose(out, np.array([[0.0, 4.0, 6.0]]))

+        layer = ReversibleEmbedding(input_dim=3, output_dim=2, logit_soft_cap=5)
+        layer.build()
+        layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]))
+        out = layer(np.array(([[1.0, 1.0]])), reverse=True)
+        self.assertAllClose(out, np.array([[0.0, 3.320184, 4.168273]]))
+
     def test_reverse_dtype(self):
         embedding = ReversibleEmbedding(100, 16, reverse_dtype="float32")
         input_data = ops.ones(shape=(4, 10, 16))
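
As a quick sanity check on the expected values in the new correctness test: with embeddings `[[0, 0], [2, 2], [3, 3]]` and input `[[1, 1]]`, the raw reverse projection is `[[0, 4, 6]]`, and a soft cap of 5 squashes that to roughly `[[0, 3.320184, 4.168273]]`, since `5 * tanh(4 / 5) ≈ 3.320184` and `5 * tanh(6 / 5) ≈ 4.168273`. A minimal NumPy reproduction:

import numpy as np

embeddings = np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]])
inputs = np.array([[1.0, 1.0]])
soft_cap = 5.0

raw = inputs @ embeddings.T  # [[0., 4., 6.]]
capped = np.tanh(raw / soft_cap) * soft_cap

print(capped)  # approximately [[0.       3.320184 4.168273]]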

keras_nlp/src/models/gemma/gemma_backbone.py

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ def __init__(
                 seed=None,
             ),
             dtype=dtype,
+            logit_soft_cap=final_logit_soft_cap,
             name="token_embedding",
         )
         self.transformer_layers = []

keras_nlp/src/models/gemma/gemma_causal_lm.py

Lines changed: 0 additions & 13 deletions
@@ -227,13 +227,6 @@ def call_with_cache(
         cache = ops.stack(caches, axis=1)
         hidden_states = x = self.backbone.layer_norm(x)
         logits = self.backbone.token_embedding(x, reverse=True)
-
-        if self.backbone.final_logit_soft_cap is not None:
-            logits = ops.divide(logits, self.backbone.final_logit_soft_cap)
-            logits = ops.multiply(
-                ops.tanh(logits), self.backbone.final_logit_soft_cap
-            )
-
         return logits, hidden_states, cache

     def _build_cache(self, token_ids):

@@ -445,12 +438,6 @@ def default_layer_intercept_fn(x, unused_i):
         x = self.backbone.layer_norm(x)
         logits = self.backbone.token_embedding(x, reverse=True)

-        if self.backbone.final_logit_soft_cap is not None:
-            logits = ops.divide(logits, self.backbone.final_logit_soft_cap)
-            logits = ops.multiply(
-                ops.tanh(logits), self.backbone.final_logit_soft_cap
-            )
-
         if scoring_mode == "logits":
             return logits
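The blocks deleted here do not remove any behaviour: `GemmaBackbone` now passes `final_logit_soft_cap` into its token embedding (see the backbone diff above), and the layer's `tanh(logits / cap) * cap` is algebraically the same as the removed `ops.divide` / `ops.tanh` / `ops.multiply` sequence. A small sketch checking that equivalence with `keras.ops` (arbitrary values, for illustration only; assumes Keras 3):

import numpy as np
from keras import ops

cap = 30.0
logits = np.random.uniform(-100.0, 100.0, size=(2, 8)).astype("float32")

# Inline form previously used in GemmaCausalLM:
old = ops.multiply(ops.tanh(ops.divide(logits, cap)), cap)
# Form now applied inside ReversibleEmbedding when reverse=True:
new = ops.tanh(logits / cap) * cap

np.testing.assert_allclose(
    ops.convert_to_numpy(old), ops.convert_to_numpy(new), rtol=1e-6
)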