
Commit ba8b150

Added Kernel and Bias Initializers to Encoder and Decoder (#50)
* Added Kernel and Bias Initializer to decoder
* Added Initializer to Encoder and Decoder
* Added Initializers to Expected Test Config
* Added Serialized Version to Config, Added New Test
* Fixed Docstring for Encoder and Decoder
* Changed initializer import to keras.initializer
* Removed Redundant Test From Encoder and Decoder
* Changed Default to Glorot Uniform and Zeros
* Ensure friendly error if bad arg on layer creation
* Ran Black Formatter
* Fixed Serialization Bug and Reran Black
* Added Additional Tests for Testing Value Error
* Keeping Attribute Set. From Const. Arg. Together
* New test for Value Error if Invalid Initializer
* Ran format and lint
* Fixed typo and also lines exceeding max length
1 parent ca22feb commit ba8b150
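
For orientation, here is a minimal usage sketch of what this change enables. It is a hypothetical snippet, not part of the commit; the layer path and argument names are taken from the docstring examples in the diffs below.

```python
from tensorflow import keras
import keras_nlp

# Both layers now accept kernel_initializer / bias_initializer, either as a
# Keras identifier string or as an initializer instance. Defaults remain
# "glorot_uniform" and "zeros".
encoder = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=64,
    num_heads=8,
    kernel_initializer="he_normal",
    bias_initializer=keras.initializers.Zeros(),
)
decoder = keras_nlp.layers.TransformerDecoder(
    intermediate_dim=64,
    num_heads=8,
    kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02),
)
```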

File tree: 4 files changed (+105 -10 lines)

  keras_nlp/layers/transformer_decoder.py
  keras_nlp/layers/transformer_decoder_test.py
  keras_nlp/layers/transformer_encoder.py
  keras_nlp/layers/transformer_encoder_test.py


keras_nlp/layers/transformer_decoder.py

Lines changed: 32 additions & 5 deletions
@@ -39,13 +39,19 @@ class TransformerDecoder(keras.layers.Layer):
             activation function of feedforward network.
         layer_norm_epsilon: float, defaults to 1e-5. The eps value in layer
             normalization components.
+        kernel_initializer: string or tf.keras.initializers initializer,
+            defaults to "glorot_uniform". The kernel initializer for
+            the dense and multiheaded attention layers.
+        bias_initializer: string or tf.keras.initializers initializer,
+            defaults to "zeros". The bias initializer for
+            the dense and multiheaded attention layers.
         name: string, defaults to None. The name of the layer.
         **kwargs: other keyword arguments.
 
     Examples:
     ```python
     # Create a single transformer decoder layer.
-    decoder = keras_nlp.layer.TransformerDecoder(
+    decoder = keras_nlp.layers.TransformerDecoder(
         intermediate_dim=64, num_heads=8)
 
     # Create a simple model containing the decoder.
@@ -74,15 +80,19 @@ def __init__(
         dropout=0,
         activation="relu",
         layer_norm_epsilon=1e-05,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
         name=None,
         **kwargs,
     ):
         super().__init__(name=name, **kwargs)
         self.intermediate_dim = intermediate_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.activation = activation
+        self.activation = keras.activations.get(activation)
         self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
         self._built = False
 
     def _build(self, input_shape):
@@ -95,12 +105,16 @@ def _build(self, input_shape):
             key_dim=self._attention_head_size,
             value_dim=self._attention_head_size,
             dropout=self.dropout,
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
         )
         self._encoder_decoder_attention_layer = keras.layers.MultiHeadAttention(
             num_heads=self.num_heads,
             key_dim=self._attention_head_size,
             value_dim=feature_size,
             dropout=self.dropout,
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
         )
 
         self._decoder_attention_layernorm = keras.layers.LayerNormalization()
@@ -114,11 +128,18 @@ def _build(self, input_shape):
         # First dense layer in the feedforward network, which maps input
         # feauture size to dimension `self.intermediate_dim`.
         self._intermediate_dense = keras.layers.Dense(
-            self.intermediate_dim, activation=self.activation
+            self.intermediate_dim,
+            activation=self.activation,
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
         )
         # Second dense layer in the feedforward network, which maps input
         # feature size back to the input feature size.
-        self._output_dense = keras.layers.Dense(feature_size)
+        self._output_dense = keras.layers.Dense(
+            feature_size,
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+        )
         self._outputdropout = keras.layers.Dropout(rate=self.dropout)
 
     def _add_and_norm(self, input1, input2, norm_layer):
@@ -219,8 +240,14 @@ def get_config(self):
                 "intermediate_dim": self.intermediate_dim,
                 "num_heads": self.num_heads,
                 "dropout": self.dropout,
-                "activation": self.activation,
+                "activation": keras.activations.serialize(self.activation),
                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
             }
         )
         return config
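
Because `__init__` now resolves both arguments through `keras.initializers.get`, an unrecognized identifier fails at layer construction rather than at build time. A small sketch of that behavior using only the standard Keras API (not code from this commit):

```python
from tensorflow import keras

# Strings are resolved to initializer objects up front...
init = keras.initializers.get("glorot_uniform")
print(type(init).__name__)  # GlorotUniform

# ...and an unknown identifier raises ValueError immediately, which is the
# "friendly error" the new tests assert for kernel_initializer="Invalid".
try:
    keras.initializers.get("Invalid")
except ValueError as err:
    print("rejected:", err)
```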

keras_nlp/layers/transformer_decoder_test.py

Lines changed: 22 additions & 0 deletions
@@ -60,24 +60,46 @@ def test_get_config_and_from_config(self):
         decoder = transformer_decoder.TransformerDecoder(
             intermediate_dim=4,
             num_heads=2,
+            kernel_initializer="HeNormal",
+            bias_initializer="Zeros",
         )
+
         config = decoder.get_config()
+
         expected_config_subset = {
             "intermediate_dim": 4,
             "num_heads": 2,
             "dropout": 0,
             "activation": "relu",
             "layer_norm_epsilon": 1e-05,
+            "kernel_initializer": keras.initializers.serialize(
+                keras.initializers.HeNormal()
+            ),
+            "bias_initializer": keras.initializers.serialize(
+                keras.initializers.Zeros()
+            ),
         }
+
+        self.assertEqual(config, {**config, **expected_config_subset})
         self.assertEqual(config, {**config, **expected_config_subset})
 
         restored_decoder = transformer_decoder.TransformerDecoder.from_config(
             config,
         )
+
         self.assertEqual(
             restored_decoder.get_config(), {**config, **expected_config_subset}
         )
 
+    def test_value_error_when_invalid_kernel_inititalizer(self):
+        with self.assertRaises(ValueError):
+            transformer_decoder.TransformerDecoder(
+                intermediate_dim=4,
+                num_heads=2,
+                dropout=0.5,
+                kernel_initializer="Invalid",
+            )
+
     def test_one_training_step_of_transformer_encoder(self):
         class MyModel(keras.Model):
             def __init__(self):
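
The expected values in this test are built with `keras.initializers.serialize` rather than hard-coded dicts. As a rough sketch of why (standard Keras behavior, shown here only for illustration):

```python
from tensorflow import keras

# serialize() returns a JSON-friendly dict, roughly
# {"class_name": "HeNormal", "config": {"seed": None}}, and deserialize()
# rebuilds an equivalent initializer from it, which is what from_config
# relies on when restoring the layer.
serialized = keras.initializers.serialize(keras.initializers.HeNormal())
restored = keras.initializers.deserialize(serialized)
print(serialized["class_name"], type(restored).__name__)
```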

keras_nlp/layers/transformer_encoder.py

Lines changed: 30 additions & 5 deletions
@@ -37,14 +37,20 @@ class TransformerEncoder(keras.layers.Layer):
             activation function of feedforward network.
         layer_norm_epsilon: float, defaults to 1e-5. The epsilon value in layer
             normalization components.
+        kernel_initializer: string or tf.keras.initializers initializer,
+            defaults to "glorot_uniform". The kernel initializer for
+            the dense and multiheaded attention layers.
+        bias_initializer: string or tf.keras.initializers initializer,
+            defaults to "zeros". The bias initializer for
+            the dense and multiheaded attention layers.
         name: string, defaults to None. The name of the layer.
         **kwargs: other keyword arguments.
 
     Examples:
 
     ```python
     # Create a single transformer encoder layer.
-    encoder = keras_nlp.layer.TransformerEncoder(
+    encoder = keras_nlp.layers.TransformerEncoder(
         intermediate_dim=64, num_heads=8)
 
     # Create a simple model containing the encoder.
@@ -69,15 +75,19 @@ def __init__(
         dropout=0,
         activation="relu",
         layer_norm_epsilon=1e-05,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
         name=None,
         **kwargs
     ):
         super().__init__(name=name, **kwargs)
         self.intermediate_dim = intermediate_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.activation = activation
+        self.activation = keras.activations.get(activation)
         self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
         self._built = False
 
     def _build(self, input_shape):
@@ -90,6 +100,8 @@ def _build(self, input_shape):
             key_dim=self._attention_head_size,
             value_dim=self._attention_head_size,
             dropout=self.dropout,
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
         )
 
         self._attention_layernorm = keras.layers.LayerNormalization()
@@ -98,9 +110,16 @@ def _build(self, input_shape):
         self._attentiondropout = keras.layers.Dropout(rate=self.dropout)
 
         self._intermediate_dense = keras.layers.Dense(
-            self.intermediate_dim, activation=self.activation
+            self.intermediate_dim,
+            activation=self.activation,
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+        )
+        self._output_dense = keras.layers.Dense(
+            feature_size,
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
         )
-        self._output_dense = keras.layers.Dense(feature_size)
         self._outputdropout = keras.layers.Dropout(rate=self.dropout)
 
     def _add_and_norm(self, input1, input2, norm_layer):
@@ -161,8 +180,14 @@ def get_config(self):
                 "intermediate_dim": self.intermediate_dim,
                 "num_heads": self.num_heads,
                 "dropout": self.dropout,
-                "activation": self.activation,
+                "activation": keras.activations.serialize(self.activation),
                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
             }
         )
         return config
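
As with the decoder, storing the initializers in serialized form keeps `get_config` JSON-friendly, so a config round trip restores an equivalent layer. A minimal round-trip sketch, assuming the public `keras_nlp.layers` path from the docstring example (not part of the commit):

```python
import keras_nlp

encoder = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=64,
    num_heads=8,
    kernel_initializer="he_normal",
)
config = encoder.get_config()  # initializers appear as serialized dicts
restored = keras_nlp.layers.TransformerEncoder.from_config(config)
```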

keras_nlp/layers/transformer_encoder_test.py

Lines changed: 21 additions & 0 deletions
@@ -50,24 +50,45 @@ def test_get_config_and_from_config(self):
         encoder = transformer_encoder.TransformerEncoder(
             intermediate_dim=4,
             num_heads=2,
+            kernel_initializer="HeNormal",
+            bias_initializer="Zeros",
         )
+
         config = encoder.get_config()
+
         expected_config_subset = {
             "intermediate_dim": 4,
             "num_heads": 2,
             "dropout": 0,
             "activation": "relu",
             "layer_norm_epsilon": 1e-05,
+            "kernel_initializer": keras.initializers.serialize(
+                keras.initializers.HeNormal()
+            ),
+            "bias_initializer": keras.initializers.serialize(
+                keras.initializers.Zeros()
+            ),
         }
+
         self.assertEqual(config, {**config, **expected_config_subset})
 
         restored_encoder = transformer_encoder.TransformerEncoder.from_config(
             config,
         )
+
         self.assertEqual(
             restored_encoder.get_config(), {**config, **expected_config_subset}
         )
 
+    def test_value_error_when_invalid_kernel_inititalizer(self):
+        with self.assertRaises(ValueError):
+            transformer_encoder.TransformerEncoder(
+                intermediate_dim=4,
+                num_heads=2,
+                dropout=0.5,
+                kernel_initializer="Invalid",
+            )
+
     def test_one_training_step_of_transformer_encoder(self):
         encoder = transformer_encoder.TransformerEncoder(
             intermediate_dim=4,
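
Note that `expected_config_subset` still lists `"activation": "relu"` even though `__init__` now stores the resolved callable: `keras.activations.serialize` maps a built-in activation back to its string name. A quick sketch of that behavior in plain Keras, for illustration only:

```python
from tensorflow import keras

fn = keras.activations.get("relu")      # string -> callable
print(keras.activations.serialize(fn))  # back to "relu"
```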
