Skip to content

Commit fee018e

Browse files
committed
updated factory opt
1 parent fcb0f3f commit fee018e

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

src/fairseq2/models/opt/_factory.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from __future__ import annotations
88

9+
from fairseq2.data_type import DataType
10+
from fairseq2.device import Device
911
from fairseq2.models.transformer import (
1012
CausalAttentionBias,
1113
FeedForwardNetwork,
@@ -62,12 +64,11 @@ def create_model(self) -> TransformerLanguageModel:
6264
final_proj = self.create_final_projection()
6365

6466
return TransformerLanguageModel(
65-
config.model_dim,
6667
decoder_frontend,
6768
decoder,
6869
final_proj,
69-
config.pad_idx,
70-
config.max_seq_len,
70+
pad_idx=config.pad_idx,
71+
max_seq_len=config.max_seq_len,
7172
)
7273

7374
def create_decoder_frontend(self) -> TransformerFrontend:
@@ -99,9 +100,11 @@ def create_decoder(self) -> TransformerLMDecoder:
99100

100101
layers.append(layer)
101102

102-
layer_norm = self.create_layer_norm()
103-
104-
return StandardTransformerLMDecoder(layers, layer_norm)
103+
return StandardTransformerLMDecoder(
104+
layers,
105+
norm_order=TransformerNormOrder.PRE,
106+
layer_norm_factory=self.create_layer_norm,
107+
)
105108

106109
def create_position_encoder(self) -> PositionEncoder:
107110
config = self._config
@@ -151,11 +154,6 @@ def create_ffn(self) -> FeedForwardNetwork:
151154

152155
return StandardFeedForwardNetwork(config.model_dim, config.ffn_inner_dim, bias=True)
153156

154-
def create_layer_norm(self) -> LayerNorm:
155-
config = self._config
156-
157-
return StandardLayerNorm(config.model_dim, bias=True)
158-
159157
def create_final_projection(self) -> Projection:
160158
config = self._config
161159

@@ -165,3 +163,7 @@ def create_final_projection(self) -> Projection:
165163
bias=False,
166164
init_fn=init_transformer_final_projection,
167165
)
166+
167+
@staticmethod
168+
def create_layer_norm(model_dim: int, *, device: Device | None = None, dtype: DataType | None = None) -> LayerNorm:
169+
return StandardLayerNorm(model_dim, bias=True, device=device, dtype=dtype)

0 commit comments

Comments (0)