
Commit e1f0175

SkafteNicki authored and Borda committed
Minor correction to demo transformer script (#21033)
* fix demo transformer
* fix typing

(cherry picked from commit bcfa4dd)
1 parent 4c66b9c commit e1f0175

File tree

1 file changed (+13 -4 lines changed)

src/lightning/pytorch/demos/transformer.py

Lines changed: 13 additions & 4 deletions
@@ -54,15 +54,24 @@ def __init__(
 
         self.ninp = ninp
         self.vocab_size = vocab_size
-        self.src_mask = None
+        self.src_mask: Optional[Tensor] = None
+
+    def generate_square_subsequent_mask(self, size: int) -> Tensor:
+        """Generate a square mask for the sequence to prevent future tokens from being seen."""
+        mask = torch.triu(torch.ones(size, size), diagonal=1)
+        mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
+        return mask
 
     def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         _, t = inputs.shape
 
-        # we assume target is already shifted w.r.t. inputs
+        # Generate source mask to prevent future token leakage
+        if self.src_mask is None or self.src_mask.size(0) != t:
+            self.src_mask = self.generate_square_subsequent_mask(t).to(inputs.device)
+
+        # Generate target mask if not provided
        if mask is None:
-            mask = torch.tril(torch.ones(t, t, device=inputs.device)) == 1
-            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
+            mask = self.generate_square_subsequent_mask(t).to(inputs.device)
 
         src = self.pos_encoder(self.embedding(inputs) * math.sqrt(self.ninp))
         target = self.pos_encoder(self.embedding(target) * math.sqrt(self.ninp))
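
For reference, here is a minimal sketch (not part of the commit) of the mask that the new generate_square_subsequent_mask logic produces, using an illustrative sequence length of 4 and assuming only torch is installed:

import torch

size = 4  # illustrative sequence length
mask = torch.triu(torch.ones(size, size), diagonal=1)
mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

The -inf entries sit strictly above the diagonal, so each position can attend only to itself and earlier tokens; the previous torch.tril-based construction removed from forward produced the same additive mask.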
