Commit 5cf2c5e

Add a mixed precision test and fix mixed precision errors for layers (#1242)
* Add a mixed precision test and fix mixed precision errors for layers
* Address comments
* fix torch cpu
1 parent ca20190 commit 5cf2c5e

12 files changed, with 120 additions and 20 deletions.
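For context on what the changed layers are being exercised against: a minimal sketch of running a KerasNLP layer under a global mixed precision policy, assuming the multi-backend (Keras 3) package and a backend/device with float16 kernels. Under "mixed_float16", every sublayer a composite layer builds must inherit the policy, which is what the `dtype=self.dtype_policy` changes below ensure.

import keras
import numpy as np

import keras_nlp

# Set mixed precision globally: computations in float16, variables in float32.
keras.mixed_precision.set_global_policy("mixed_float16")

encoder = keras_nlp.layers.TransformerEncoder(intermediate_dim=64, num_heads=2)
x = np.random.uniform(size=(2, 10, 16)).astype("float32")
y = encoder(x)
print(y.dtype)                              # float16 activations
print(encoder.dtype_policy.variable_dtype)  # float32 variables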

keras_nlp/layers/modeling/cached_multi_head_attention_test.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from keras_nlp.backend import config
 from keras_nlp.backend import ops
 from keras_nlp.layers.modeling.cached_multi_head_attention import (
     CachedMultiHeadAttention,
@@ -34,6 +35,9 @@ def test_layer_behaviors(self):
             expected_output_shape=(2, 4, 6),
             expected_num_trainable_weights=8,
             expected_num_non_trainable_variables=1,
+            # tf.keras does not handle mixed precision correctly when not set
+            # globally.
+            run_mixed_precision_check=config.multi_backend(),
         )
 
     def test_cache_call_is_correct(self):
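The gating above means the mixed precision check only runs on multi-backend Keras, where a dtype policy can be passed per layer. A hedged manual sketch of what that check exercises (shapes and the printed check are illustrative, not the test harness internals):

import numpy as np

from keras_nlp.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

# Per-layer policy (Keras 3 only): compute in float16, keep variables in float32.
layer = CachedMultiHeadAttention(num_heads=2, key_dim=4, dtype="mixed_float16")
x = np.random.uniform(size=(2, 4, 6)).astype("float32")
outputs = layer(query=x, value=x)
print(outputs.dtype)  # float16 on backends/devices with half-precision kernels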

keras_nlp/layers/modeling/f_net_encoder.py

Lines changed: 13 additions & 2 deletions
@@ -97,11 +97,13 @@ def build(self, inputs_shape):
         # Layer Norm layers.
         self._mixing_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
             name="mixing_layer_norm",
         )
         self._mixing_layer_norm.build(inputs_shape)
         self._output_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
             name="output_layer_norm",
         )
         self._output_layer_norm.build(inputs_shape)
@@ -112,19 +114,25 @@ def build(self, inputs_shape):
             activation=self.activation,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="intermediate_dense",
         )
         self._intermediate_dense.build(inputs_shape)
         self._output_dense = keras.layers.Dense(
             feature_size,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="output_dense",
         )
         self._output_dense.build(
             self._intermediate_dense.compute_output_shape(inputs_shape)
         )
-        self._output_dropout = keras.layers.Dropout(rate=self.dropout)
+        self._output_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="output_dropout",
+        )
         self.built = True
 
     def call(self, inputs):
@@ -140,9 +148,12 @@ def call(self, inputs):
 
         def fourier_transform(input):
             # Apply FFT on the input and take the real part.
+            input_dtype = input.dtype
+            # FFT transforms do not support float16.
+            input = ops.cast(input, "float32")
             real_in, imaginary_in = (input, ops.zeros_like(input))
             real_out, _ = ops.fft2((real_in, imaginary_in))
-            return real_out
+            return ops.cast(real_out, input_dtype)
 
         def add_and_norm(input1, input2, norm_layer):
             return norm_layer(input1 + input2)
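The `fourier_transform` change above works around missing half-precision FFT kernels by always computing the FFT in float32 and casting the real part back. The same pattern as a standalone sketch (the function name here is illustrative):

from keras_nlp.backend import ops


def real_fft2(x):
    # FFT kernels generally do not support float16/bfloat16 inputs.
    input_dtype = x.dtype
    x = ops.cast(x, "float32")
    real_out, _ = ops.fft2((x, ops.zeros_like(x)))
    return ops.cast(real_out, input_dtype)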

keras_nlp/layers/modeling/masked_lm_head.py

Lines changed: 3 additions & 0 deletions
@@ -141,10 +141,12 @@ def build(self, inputs_shape, mask_positions_shape=None):
             activation=self.intermediate_activation,
             kernel_initializer=self.kernel_initializer,
             bias_initializer=self.bias_initializer,
+            dtype=self.dtype_policy,
             name="intermediate_dense",
         )
         self._intermediate_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
             name="intermediate_layer_norm",
         )
         # The gather length does not affect any of our built variables, so
@@ -185,6 +187,7 @@ def call(self, inputs, mask_positions):
             outputs = self.token_embedding(x, reverse=True)
         else:
             outputs = ops.matmul(x, self._kernel)
+        outputs = ops.cast(outputs, self.compute_dtype)
         outputs = outputs + self._bias
 
         # Apply a final activation.

keras_nlp/layers/modeling/reversible_embedding.py

Lines changed: 11 additions & 5 deletions
@@ -49,6 +49,9 @@ class ReversibleEmbedding(keras.layers.Embedding):
             the `embeddings` matrix (see `keras.constraints`).
         mask_zero: Boolean, whether or not the input value 0 is a special
             "padding" value that should be masked out.
+        reverse_dtype: The dtype for the reverse projection computation.
+            For stability, it is usually best to use full precision even when
+            working with half or mixed precision training.
 
     Call args:
         inputs: The tensor inputs to the layer.
@@ -87,6 +90,7 @@ def __init__(
         embeddings_regularizer=None,
         embeddings_constraint=None,
         mask_zero=False,
+        reverse_dtype="float32",
         **kwargs,
     ):
         super().__init__(
@@ -99,6 +103,7 @@ def __init__(
             **kwargs,
         )
         self.tie_weights = tie_weights
+        self.reverse_dtype = reverse_dtype
 
     def build(self, inputs_shape=None):
         super().build(inputs_shape)
@@ -114,12 +119,12 @@ def build(self, inputs_shape=None):
     def call(self, inputs, reverse=False):
         if reverse:
             if self.tie_weights:
-                reverse_embeddings = ops.transpose(
-                    ops.convert_to_tensor(self.embeddings)
-                )
+                kernel = ops.transpose(ops.convert_to_tensor(self.embeddings))
             else:
-                reverse_embeddings = self.reverse_embeddings
-            return ops.matmul(inputs, reverse_embeddings)
+                kernel = self.reverse_embeddings
+            inputs = ops.cast(inputs, self.reverse_dtype)
+            kernel = ops.cast(kernel, self.reverse_dtype)
+            return ops.matmul(inputs, kernel)
 
         return super().call(inputs)
 
@@ -128,6 +133,7 @@ def get_config(self):
         config.update(
             {
                 "tie_weights": self.tie_weights,
+                "reverse_dtype": self.reverse_dtype,
             }
         )
         return config
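A usage sketch for the new `reverse_dtype` argument, assuming the multi-backend (Keras 3) package: the forward embedding follows the layer's (possibly half-precision) policy, while the reverse projection defaults to float32 for numerically stable logits.

from keras_nlp.backend import ops
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding

embedding = ReversibleEmbedding(100, 16, dtype="mixed_float16")
token_ids = ops.ones(shape=(4, 10), dtype="int32")
hidden = embedding(token_ids)             # float16 under the mixed policy
logits = embedding(hidden, reverse=True)  # float32, since reverse_dtype="float32"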

keras_nlp/layers/modeling/reversible_embedding_test.py

Lines changed: 20 additions & 0 deletions
@@ -17,6 +17,7 @@
 import numpy as np
 from absl.testing import parameterized
 
+from keras_nlp.backend import config
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
@@ -73,3 +74,22 @@ def test_tied_checkpoint_untied_weights(self):
 
         input_data = ops.ones(shape=(4, 10), dtype="int32")
         self.assertAllClose(untied_model(input_data), tied_model(input_data))
+
+    def test_reverse_dtype(self):
+        embedding = ReversibleEmbedding(100, 16, reverse_dtype="float32")
+        input_data = ops.ones(shape=(4, 10, 16))
+        output_data = embedding(input_data, reverse=True)
+        self.assertEqual(output_data.shape, (4, 10, 100))
+        self.assertDTypeEqual(output_data, "float32")
+
+        if config.backend() == "torch":
+            import torch
+
+            if not torch.cuda.is_available():
+                self.skipTest("Torch CPU does not support float16")
+
+        embedding = ReversibleEmbedding(100, 16, reverse_dtype="float16")
+        input_data = ops.ones(shape=(4, 10, 16))
+        output_data = embedding(input_data, reverse=True)
+        self.assertEqual(output_data.shape, (4, 10, 100))
+        self.assertDTypeEqual(output_data, "float16")

keras_nlp/layers/modeling/token_and_position_embedding.py

Lines changed: 2 additions & 0 deletions
@@ -93,11 +93,13 @@ def __init__(
                 self.embeddings_initializer
             ),
             mask_zero=mask_zero,
+            dtype=self.dtype_policy,
             name="token_embedding",
         )
         self.position_embedding = PositionEmbedding(
             sequence_length=sequence_length,
             initializer=clone_initializer(self.embeddings_initializer),
+            dtype=self.dtype_policy,
             name="position_embedding",
         )
         self.supports_masking = self.token_embedding.supports_masking
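The recurring fix in this commit is the one applied above: a composite layer must pass its own `dtype_policy` to every sublayer it constructs, otherwise the sublayers fall back to the global policy and a per-layer mixed precision setting is silently ignored. A toy sketch of the pattern (the class here is illustrative, not part of KerasNLP):

from keras_nlp.backend import keras


class ToyBlock(keras.layers.Layer):
    def __init__(self, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim

    def build(self, inputs_shape):
        self.dense = keras.layers.Dense(
            self.hidden_dim,
            dtype=self.dtype_policy,  # propagate mixed precision to the sublayer
            name="dense",
        )
        self.dense.build(inputs_shape)
        self.built = True

    def call(self, inputs):
        return self.dense(inputs)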

keras_nlp/layers/modeling/transformer_decoder.py

Lines changed: 17 additions & 5 deletions
@@ -111,15 +111,14 @@ def __init__(
         kernel_initializer="glorot_uniform",
         bias_initializer="zeros",
         normalize_first=False,
-        name=None,
         **kwargs,
     ):
         # Work around for model saving, we need to ensure our model is built
         # immediately after restoring from config.
         decoder_sequence_shape = kwargs.pop("decoder_sequence_shape", None)
         encoder_sequence_shape = kwargs.pop("encoder_sequence_shape", None)
 
-        super().__init__(name=name, **kwargs)
+        super().__init__(**kwargs)
         self.intermediate_dim = intermediate_dim
         self.num_heads = num_heads
         self.dropout = dropout
@@ -160,6 +159,7 @@ def build(
             dropout=self.dropout,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="self_attention",
         )
         if hasattr(self._self_attention_layer, "_build_from_signature"):
@@ -174,11 +174,14 @@ def build(
         )
         self._self_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
             name="self_attention_layer_norm",
         )
         self._self_attention_layer_norm.build(decoder_sequence_shape)
         self._self_attention_dropout = keras.layers.Dropout(
             rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
         )
 
         # Cross attention layers are optional.
@@ -191,6 +194,7 @@ def build(
             dropout=self.dropout,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="cross_attention",
         )
         if hasattr(self._cross_attention_layer, "_build_from_signature"):
@@ -205,11 +209,14 @@ def build(
         )
         self._cross_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
             name="cross_attention_layer_norm",
         )
         self._cross_attention_layer_norm.build(encoder_sequence_shape)
         self._cross_attention_dropout = keras.layers.Dropout(
             rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="cross_attention_dropout",
         )
 
         # Feedforward layers.
@@ -218,25 +225,30 @@ def build(
             activation=self.activation,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
-            name="intermediate_dense",
+            dtype=self.dtype_policy,
+            name="feedforward_intermediate_dense",
         )
         self._feedforward_intermediate_dense.build(decoder_sequence_shape)
         self._feedforward_output_dense = keras.layers.Dense(
             hidden_dim,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
-            name="output_dense",
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
        )
         intermediate_shape = list(decoder_sequence_shape)
         intermediate_shape[-1] = self.intermediate_dim
         self._feedforward_output_dense.build(tuple(intermediate_shape))
         self._feedforward_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
-            name="output_layer_norm",
+            dtype=self.dtype_policy,
+            name="feedforward_layer_norm",
         )
         self._feedforward_layer_norm.build(decoder_sequence_shape)
         self._feedforward_dropout = keras.layers.Dropout(
             rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="feedforward_dropout",
         )
         # Create layers based on input shape.
         self.built = True

keras_nlp/layers/modeling/transformer_encoder.py

Lines changed: 10 additions & 2 deletions
@@ -92,10 +92,9 @@ def __init__(
         kernel_initializer="glorot_uniform",
         bias_initializer="zeros",
         normalize_first=False,
-        name=None,
         **kwargs,
     ):
-        super().__init__(name=name, **kwargs)
+        super().__init__(**kwargs)
         self.intermediate_dim = intermediate_dim
         self.num_heads = num_heads
         self.dropout = dropout
@@ -125,6 +124,7 @@ def build(self, inputs_shape):
             dropout=self.dropout,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="self_attention_layer",
         )
         if hasattr(self._self_attention_layer, "_build_from_signature"):
@@ -139,38 +139,46 @@ def build(self, inputs_shape):
         )
         self._self_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
             name="self_attention_layer_norm",
         )
         self._self_attention_layer_norm.build(inputs_shape)
         self._self_attention_dropout = keras.layers.Dropout(
             rate=self.dropout,
+            dtype=self.dtype_policy,
             name="self_attention_dropout",
         )
 
         # Feedforward layers.
         self._feedforward_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layer_norm",
         )
         self._feedforward_layer_norm.build(inputs_shape)
         self._feedforward_intermediate_dense = keras.layers.Dense(
             self.intermediate_dim,
             activation=self.activation,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="feedforward_intermediate_dense",
         )
         self._feedforward_intermediate_dense.build(inputs_shape)
         self._feedforward_output_dense = keras.layers.Dense(
             hidden_dim,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="feedforward_output_dense",
         )
         intermediate_shape = list(inputs_shape)
         intermediate_shape[-1] = self.intermediate_dim
         self._feedforward_output_dense.build(tuple(intermediate_shape))
         self._feedforward_dropout = keras.layers.Dropout(
             rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="feedforward_dropout",
         )
         self.built = True
keras_nlp/samplers/beam_sampler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import tree
1717

1818
from keras_nlp.api_export import keras_nlp_export
19-
from keras_nlp.backend import keras
2019
from keras_nlp.backend import ops
2120
from keras_nlp.samplers.sampler import Sampler
2221
from keras_nlp.samplers.sampler import call_args_docstring
@@ -161,7 +160,7 @@ def body(prompt, cache, index, log_probs):
161160
# Compute the softmax distribution for the next token.
162161
logits, _, cache = next(prompt, cache, index)
163162
vocab_size = ops.shape(logits)[-1]
164-
probs = keras.activations.softmax(logits / self.temperature)
163+
probs = self.compute_probabilities(logits)
165164

166165
# Compute the running log-likelihood of each new candidate.
167166
next_log_probs = ops.log(probs) + log_probs[..., None]

keras_nlp/samplers/contrastive_sampler.py

Lines changed: 1 addition & 2 deletions
@@ -15,7 +15,6 @@
 import tree
 
 from keras_nlp.api_export import keras_nlp_export
-from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
 from keras_nlp.samplers.sampler import call_args_docstring
@@ -131,7 +130,7 @@ def cond(prompt, cache, index, logits, hidden_states):
 
         def body(prompt, cache, index, logits, hidden_states):
             # Compute the softmax distribution for the next token.
-            probabilities = keras.activations.softmax(logits / self.temperature)
+            probabilities = self.compute_probabilities(logits)
 
             # Replicate for `self.k` times to find the best token in top-k
             # candidates.
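Both samplers now delegate to a shared `compute_probabilities` helper on the base `Sampler` (defined in keras_nlp/samplers/sampler.py, not shown in this diff). A hedged sketch of what it plausibly does: apply the temperature and run the softmax in float32 regardless of the logits dtype, then cast back.

from keras_nlp.backend import keras
from keras_nlp.backend import ops


def compute_probabilities(logits, temperature=1.0):
    # Softmax in full precision for numerical stability, then restore the dtype.
    logits_dtype = logits.dtype
    logits = ops.cast(logits, "float32")
    probabilities = keras.activations.softmax(logits / temperature)
    return ops.cast(probabilities, logits_dtype)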
