@@ -111,15 +111,14 @@ def __init__(
         kernel_initializer="glorot_uniform",
         bias_initializer="zeros",
         normalize_first=False,
-        name=None,
         **kwargs,
     ):
         # Work around for model saving, we need to ensure our model is built
         # immediately after restoring from config.
         decoder_sequence_shape = kwargs.pop("decoder_sequence_shape", None)
         encoder_sequence_shape = kwargs.pop("encoder_sequence_shape", None)

-        super().__init__(name=name, **kwargs)
+        super().__init__(**kwargs)
         self.intermediate_dim = intermediate_dim
         self.num_heads = num_heads
         self.dropout = dropout
@@ -160,6 +159,7 @@ def build(
             dropout=self.dropout,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="self_attention",
         )
         if hasattr(self._self_attention_layer, "_build_from_signature"):
@@ -174,11 +174,14 @@ def build(
             )
         self._self_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
             name="self_attention_layer_norm",
         )
         self._self_attention_layer_norm.build(decoder_sequence_shape)
         self._self_attention_dropout = keras.layers.Dropout(
             rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
         )

         # Cross attention layers are optional.
@@ -191,6 +194,7 @@ def build(
             dropout=self.dropout,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
             name="cross_attention",
         )
         if hasattr(self._cross_attention_layer, "_build_from_signature"):
@@ -205,11 +209,14 @@ def build(
             )
         self._cross_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
             name="cross_attention_layer_norm",
         )
         self._cross_attention_layer_norm.build(encoder_sequence_shape)
         self._cross_attention_dropout = keras.layers.Dropout(
             rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="cross_attention_dropout",
         )

         # Feedforward layers.
@@ -218,25 +225,30 @@ def build(
             activation=self.activation,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
-            name="intermediate_dense",
+            dtype=self.dtype_policy,
+            name="feedforward_intermediate_dense",
         )
         self._feedforward_intermediate_dense.build(decoder_sequence_shape)
         self._feedforward_output_dense = keras.layers.Dense(
             hidden_dim,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             bias_initializer=clone_initializer(self.bias_initializer),
-            name="output_dense",
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
         )
         intermediate_shape = list(decoder_sequence_shape)
         intermediate_shape[-1] = self.intermediate_dim
         self._feedforward_output_dense.build(tuple(intermediate_shape))
         self._feedforward_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
-            name="output_layer_norm",
+            dtype=self.dtype_policy,
+            name="feedforward_layer_norm",
         )
         self._feedforward_layer_norm.build(decoder_sequence_shape)
         self._feedforward_dropout = keras.layers.Dropout(
             rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="feedforward_dropout",
        )
         # Create layers based on input shape.
         self.built = True
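For context, the pattern applied throughout these hunks is to forward the parent layer's dtype policy to every sublayer created in `build()` (attention, layer norm, dropout, dense), and to let `name` reach the base `Layer` through `**kwargs` instead of an explicit argument. Below is a minimal sketch of that pattern, assuming Keras 3 semantics; the `ToyBlock` layer and its sublayer are hypothetical and not part of this PR.

import keras


class ToyBlock(keras.layers.Layer):
    """Minimal sketch: a sublayer built in build() inherits the parent's dtype policy."""

    def __init__(self, hidden_dim, **kwargs):
        # `name`, `dtype`, etc. are forwarded to the base Layer via **kwargs,
        # mirroring the removal of the explicit `name=None` argument above.
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim

    def build(self, input_shape):
        # Without `dtype=self.dtype_policy`, a sublayer created here would fall
        # back to the global dtype policy, which can differ from the policy
        # set on this parent layer.
        self._dense = keras.layers.Dense(
            self.hidden_dim,
            dtype=self.dtype_policy,
            name="dense",
        )
        self._dense.build(input_shape)
        self.built = True

    def call(self, inputs):
        return self._dense(inputs)


block = ToyBlock(8, dtype="mixed_float16", name="toy_block")
block.build((None, 4))
print(block.dtype_policy, block._dense.dtype_policy)  # both report mixed_float16

Naming each dropout and feedforward sublayer explicitly, as the diff does, also keeps checkpoint variable paths stable and human-readable.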