@@ -6,6 +6,8 @@

 from __future__ import annotations

+from fairseq2.data_type import DataType
+from fairseq2.device import Device
 from fairseq2.models.transformer import (
     CausalAttentionBias,
     FeedForwardNetwork,
@@ -62,12 +64,11 @@ def create_model(self) -> TransformerLanguageModel:
         final_proj = self.create_final_projection()

         return TransformerLanguageModel(
-            config.model_dim,
             decoder_frontend,
             decoder,
             final_proj,
-            config.pad_idx,
-            config.max_seq_len,
+            pad_idx=config.pad_idx,
+            max_seq_len=config.max_seq_len,
         )

     def create_decoder_frontend(self) -> TransformerFrontend:
@@ -99,9 +100,11 @@ def create_decoder(self) -> TransformerLMDecoder:

             layers.append(layer)

-        layer_norm = self.create_layer_norm()
-
-        return StandardTransformerLMDecoder(layers, layer_norm)
+        return StandardTransformerLMDecoder(
+            layers,
+            norm_order=TransformerNormOrder.PRE,
+            layer_norm_factory=self.create_layer_norm,
+        )

     def create_position_encoder(self) -> PositionEncoder:
         config = self._config
@@ -151,11 +154,6 @@ def create_ffn(self) -> FeedForwardNetwork:

         return StandardFeedForwardNetwork(config.model_dim, config.ffn_inner_dim, bias=True)

-    def create_layer_norm(self) -> LayerNorm:
-        config = self._config
-
-        return StandardLayerNorm(config.model_dim, bias=True)
-
     def create_final_projection(self) -> Projection:
         config = self._config

@@ -165,3 +163,7 @@ def create_final_projection(self) -> Projection:
             bias=False,
             init_fn=init_transformer_final_projection,
         )
+
+    @staticmethod
+    def create_layer_norm(model_dim: int, *, device: Device | None = None, dtype: DataType | None = None) -> LayerNorm:
+        return StandardLayerNorm(model_dim, bias=True, device=device, dtype=dtype)
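
For context on the refactor above: `create_layer_norm` no longer reads `self._config`, so it can be handed to the decoder as a `layer_norm_factory` and invoked with whatever `model_dim`, `device`, and `dtype` the decoder needs at construction time. Below is a minimal sketch of how such a factory might be consumed; `SketchDecoder` and `LayerNormFactory` are hypothetical names for illustration, not fairseq2's actual `StandardTransformerLMDecoder` API.

```python
# Minimal sketch (not fairseq2's real decoder) of how a layer norm factory
# such as the static create_layer_norm above can be consumed.
from __future__ import annotations

from typing import Callable, Sequence

import torch
from torch.nn import LayerNorm, Module, ModuleList

# A factory takes the normalized dimension plus optional device/dtype and
# returns a LayerNorm; it carries no reference to the builder's config.
LayerNormFactory = Callable[..., LayerNorm]


def create_layer_norm(
    model_dim: int,
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> LayerNorm:
    # Stand-in for StandardLayerNorm(model_dim, bias=True, ...).
    return LayerNorm(model_dim, device=device, dtype=dtype)


class SketchDecoder(Module):
    """Hypothetical stand-in for StandardTransformerLMDecoder."""

    def __init__(
        self,
        layers: Sequence[Module],
        model_dim: int,
        *,
        layer_norm_factory: LayerNormFactory,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        self.layers = ModuleList(layers)
        # The decoder, not the builder, instantiates the final norm, so the
        # norm is created on the decoder's own device and dtype.
        self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)


decoder = SketchDecoder([], 512, layer_norm_factory=create_layer_norm)
print(decoder.layer_norm)  # e.g. LayerNorm((512,), eps=1e-05, ...)
```

With this shape, a builder subclass can swap in a different norm variant simply by overriding `create_layer_norm`; the decoder never needs to know which implementation it received.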