
Commit 7038b8d

Galaxy-Husky authored and lantiga committed
fix: correct the positional encoding of Transformer in pytorch examples
1 parent b0aa504 commit 7038b8d

File tree

1 file changed (+2 -2 lines)


src/lightning/pytorch/demos/transformer.py

Lines changed: 2 additions & 2 deletions
@@ -88,7 +88,7 @@ def forward(self, x: Tensor) -> Tensor:
             # TODO: Could make this a `nn.Parameter` with `requires_grad=False`
             self.pe = self._init_pos_encoding(device=x.device)

-        x = x + self.pe[: x.size(0), :]
+        x = x + self.pe[:, : x.size(1)]
         return self.dropout(x)

     def _init_pos_encoding(self, device: torch.device) -> Tensor:
@@ -97,7 +97,7 @@ def _init_pos_encoding(self, device: torch.device) -> Tensor:
         div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0).transpose(0, 1)
+        pe = pe.unsqueeze(0)
         return pe

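Taken together, the two changes make the demo's positional encoding batch-first: `_init_pos_encoding` now returns a buffer of shape `(1, max_len, dim)` rather than `(max_len, 1, dim)`, and `forward` slices that buffer along the sequence dimension so it broadcasts over the batch. Below is a minimal sketch of the corrected behaviour; the class is a simplified stand-in for the demo module, and the constructor arguments (`dim`, `dropout`, `max_len`) and the shape check at the end are assumptions for illustration, not the full source of `src/lightning/pytorch/demos/transformer.py`:

import math
from typing import Optional

import torch
from torch import Tensor, nn


class PositionalEncoding(nn.Module):
    """Simplified, batch-first sketch of the demo's sinusoidal positional encoding."""

    def __init__(self, dim: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.dropout = nn.Dropout(p=dropout)
        self.pe: Optional[Tensor] = None

    def forward(self, x: Tensor) -> Tensor:
        # x is batch-first: (batch_size, seq_len, dim)
        if self.pe is None:
            self.pe = self._init_pos_encoding(device=x.device)
        # pe has shape (1, max_len, dim); slicing the sequence dimension lets
        # broadcasting expand it over the batch dimension.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

    def _init_pos_encoding(self, device: torch.device) -> Tensor:
        pe = torch.zeros(self.max_len, self.dim, device=device)
        position = torch.arange(0, self.max_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # (max_len, dim) -> (1, max_len, dim). The old `.unsqueeze(0).transpose(0, 1)`
        # produced (max_len, 1, dim), which only lines up with sequence-first inputs.
        return pe.unsqueeze(0)


# Quick shape check (hypothetical usage):
enc = PositionalEncoding(dim=32)
out = enc(torch.zeros(4, 10, 32))  # batch of 4, sequence length 10
assert out.shape == (4, 10, 32)

With the old `(max_len, 1, dim)` buffer, adding it to a batch-first input of shape `(batch, seq_len, dim)` broadcasts along the wrong axes; keeping the buffer as `(1, max_len, dim)` and slicing with `[:, : x.size(1)]` matches a batch-first input layout.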