@@ -116,10 +116,11 @@ def __init__(
         layer_norm_eps: float = 1e-6,
         attention_temp: float = 1.0,
         pre_ln: bool = True,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
-        self.pos_encoder = PositionalEncoding(d_model)
-        self.shared_embedding = nn.Embedding(ntoken, d_model)
+        self.pos_encoder = PositionalEncoding(d_model, dtype=dtype)
+        self.shared_embedding = nn.Embedding(ntoken, d_model, dtype=dtype)
         self.encoder = Encoder(
             d_model,
             nhead,
@@ -130,6 +131,7 @@ def __init__(
             layer_norm_eps,
             attention_temp,
             pre_ln,
+            dtype=dtype,
         )
         self.decoder = Decoder(
             d_model,
@@ -141,6 +143,7 @@ def __init__(
             layer_norm_eps,
             attention_temp,
             pre_ln,
+            dtype=dtype,
         )
         # Share positional encoding and embedding between encoder and decoder.
         self.encoder.pos_encoder = self.pos_encoder
@@ -287,6 +290,7 @@ def __init__(
         layer_norm_eps: float = 1e-6,
         attention_temp: float = 1.0,
         pre_ln: bool = True,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
         self.nhead = nhead
@@ -301,8 +305,11 @@ def __init__(
             layer_norm_eps=layer_norm_eps,
             attention_temp=attention_temp,
             pre_ln=pre_ln,
+            dtype=dtype,
+        )
+        encoder_norm = (
+            nn.LayerNorm(d_model, eps=layer_norm_eps, dtype=dtype) if pre_ln else None
         )
-        encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None
         self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm)
 
     def forward(
@@ -332,6 +339,7 @@ def __init__(
         layer_norm_eps: float = 1e-6,
         attention_temp: float = 1.0,
         pre_ln: bool = True,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
         self.nhead = nhead
@@ -347,6 +355,7 @@ def __init__(
             nlayers,
             attention_temp,
             pre_ln,
+            dtype=dtype,
         )
 
     def forward(
@@ -398,13 +407,18 @@ def forward(
 
 
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model: int, max_len: int = 256):
+    def __init__(
+        self,
+        d_model: int,
+        max_len: int = 256,
+        dtype: torch.dtype = torch.float32,
+    ):
         super().__init__()
 
         position = torch.arange(max_len).unsqueeze(1)
         scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
         div_term = torch.exp(torch.arange(d_model // 2) * scale_factor)
-        pe = torch.zeros(1, max_len, d_model)
+        pe = torch.zeros(1, max_len, d_model, dtype=dtype)
         pe[0, :, : d_model // 2] = torch.sin(position * div_term)
         pe[0, :, d_model // 2 : 2 * (d_model // 2)] = torch.cos(position * div_term)
         self.register_buffer('pe', pe)
@@ -599,6 +613,7 @@ def __init__(
         num_layers,
         attention_temp,
         pre_ln,
+        dtype: torch.dtype = torch.float32,
     ):
         super().__init__()
         self.layers = nn.ModuleList(
@@ -612,12 +627,15 @@ def __init__(
                     layer_norm_eps=layer_norm_eps,
                     attention_temp=attention_temp,
                     pre_ln=pre_ln,
+                    dtype=dtype,
                 )
                 for _ in range(num_layers)
             ]
         )
         self.num_layers = num_layers
-        self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None
+        self.norm = (
+            nn.LayerNorm(d_model, eps=layer_norm_eps, dtype=dtype) if pre_ln else None
+        )
 
     def forward(
         self,
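
For context, here is a minimal standalone sketch (not part of the commit) of the PositionalEncoding class with the new `dtype` argument, plus a small check that the registered `pe` buffer is created in the requested dtype. The imports and the `__main__` check are assumptions added purely for illustration; the class body mirrors the diff above.

# Standalone sketch (not from the commit): PositionalEncoding with the new
# `dtype` argument, illustrating that the buffer picks up that dtype.
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    def __init__(
        self,
        d_model: int,
        max_len: int = 256,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
        div_term = torch.exp(torch.arange(d_model // 2) * scale_factor)
        # The sin/cos values are computed in float32 and cast on assignment
        # into the `pe` buffer, which is allocated in the requested dtype.
        pe = torch.zeros(1, max_len, d_model, dtype=dtype)
        pe[0, :, : d_model // 2] = torch.sin(position * div_term)
        pe[0, :, d_model // 2 : 2 * (d_model // 2)] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)


if __name__ == '__main__':
    # Hypothetical usage: request a bfloat16 positional-encoding buffer.
    enc = PositionalEncoding(d_model=16, dtype=torch.bfloat16)
    print(enc.pe.dtype)  # expected: torch.bfloat16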