Skip to content

Commit 4346d91

Browse files
committed
Fix a typo in cross mlp forwarding.
1 parent 16640d8 commit 4346d91

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

qmb/crossmlp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
146146
for layer in self.momentum:
147147
new_emb = layer(emb)
148148
new_emb = new_emb + emb
149-
new_emb = new_emb - new_emb.mean(dim=0, keepdim=True)
149+
emb = new_emb - new_emb.mean(dim=0, keepdim=True)
150150
emb = emb / emb.norm(p=2, dim=1, keepdim=True)
151151
else:
152152
raise ValueError(f"Invalid kind: {self.kind}")

0 commit comments

Comments
 (0)