Commit f56adda

Kye Gomez authored and committed
[CLEANUP]
1 parent c1a43cc commit f56adda

File tree

1 file changed (+12, -3 lines)


bitnet/bit_transformer.py

Lines changed: 12 additions & 3 deletions
@@ -75,12 +75,15 @@ def __init__(
                     dropout=0.1,
                 ),
             )
+
+        # Norm
+        self.norm = nn.LayerNorm(dim)

     def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
         skip = x
         for attn, ffn in zip(self.layers, self.ffn_layers):
             x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
-            x = x + skip
+            x = self.norm(x + skip)
             x = ffn(x) + x
         return x
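The hunk above adds a LayerNorm to the decoder and applies it to the attention residual, replacing the plain sum x = x + skip with x = self.norm(x + skip). A minimal, self-contained sketch of the same pattern, with nn.MultiheadAttention standing in for the repository's attention module (an assumption; the real module takes is_causal=True directly, so the explicit causal mask below exists only for the stand-in):

import torch
from torch import nn

dim, heads, seq_len = 64, 8, 16
attn = nn.MultiheadAttention(dim, heads, dropout=0.1, batch_first=True)  # stand-in attention block
ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))  # stand-in FFN
norm = nn.LayerNorm(dim)  # the norm introduced by this commit

x = torch.randn(2, seq_len, dim)
skip = x
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
out, _ = attn(x, x, x, attn_mask=causal_mask)
x = norm(out + skip)  # previously: x = out + skip
x = ffn(x) + x        # unchanged FFN residual
print(x.shape)        # torch.Size([2, 16, 64])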

@@ -117,8 +120,8 @@ def __init__(
         dim: int,
         depth: int,
         num_tokens: int,
-        heads=8,
-        ff_mult=4,
+        heads: int = 8,
+        ff_mult: int = 4,
     ):
         super().__init__()
         self.emb = nn.Embedding(num_tokens, dim)
@@ -132,8 +135,14 @@ def __init__(
             dim,
             vocab_size=num_tokens,
         )
+
+        # Norm
+        self.norm = nn.LayerNorm(dim)

     def forward(self, x):
         x = self.emb(x)
+        # Post emb norm
+        x = self.norm(x)
+
         x = self.transformer(x)
         return self.to_logits(x)
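The remaining hunks add explicit int annotations to heads and ff_mult and insert the same LayerNorm right after the token embedding, so embeddings are normalized before they enter the decoder stack. A rough sketch of the resulting wrapper follows; the decoder stack and output projection are simplified stand-ins for the repository's own BitNet modules, not the actual implementation:

import torch
from torch import nn

class BitNetTransformerSketch(nn.Module):
    # Sketch only: nn.TransformerEncoderLayer and nn.Linear are placeholders
    # for the repository's BitNet decoder and output head.
    def __init__(self, dim: int, depth: int, num_tokens: int, heads: int = 8, ff_mult: int = 4):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        self.transformer = nn.Sequential(
            *[
                nn.TransformerEncoderLayer(dim, heads, dim * ff_mult, batch_first=True)
                for _ in range(depth)
            ]
        )
        self.to_logits = nn.Linear(dim, num_tokens)
        # Norm added by this commit
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.emb(x)
        # Post emb norm, new in this commit
        x = self.norm(x)
        x = self.transformer(x)
        return self.to_logits(x)

tokens = torch.randint(0, 1000, (2, 16))
logits = BitNetTransformerSketch(dim=64, depth=2, num_tokens=1000)(tokens)
print(logits.shape)  # torch.Size([2, 16, 1000])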
