From 1cb048b3255bdfdee1f3376791501bea50cde452 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 5 Aug 2025 09:20:29 +0200
Subject: [PATCH 1/2] fix demo transformer

---
 src/lightning/pytorch/demos/transformer.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index fefa073fbd310..7eb095f722dc2 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -56,13 +56,22 @@ def __init__(
         self.vocab_size = vocab_size
         self.src_mask = None
 
+    def generate_square_subsequent_mask(self, size: int) -> Tensor:
+        """Generate a square mask for the sequence to prevent future tokens from being seen."""
+        mask = torch.triu(torch.ones(size, size), diagonal=1)
+        mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
+        return mask
+
     def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         _, t = inputs.shape
 
-        # we assume target is already shifted w.r.t. inputs
+        # Generate source mask to prevent future token leakage
+        if self.src_mask is None or self.src_mask.size(0) != t:
+            self.src_mask = self.generate_square_subsequent_mask(t).to(inputs.device)
+
+        # Generate target mask if not provided
         if mask is None:
-            mask = torch.tril(torch.ones(t, t, device=inputs.device)) == 1
-            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
+            mask = self.generate_square_subsequent_mask(t).to(inputs.device)
 
         src = self.pos_encoder(self.embedding(inputs) * math.sqrt(self.ninp))
         target = self.pos_encoder(self.embedding(target) * math.sqrt(self.ninp))

From 122cf27fd612a78f7dea3e9d10d08a20acf94880 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Thu, 7 Aug 2025 09:07:49 +0200
Subject: [PATCH 2/2] fix typing

---
 src/lightning/pytorch/demos/transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index 7eb095f722dc2..13b5e05adc680 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -54,7 +54,7 @@ def __init__(
 
         self.ninp = ninp
         self.vocab_size = vocab_size
-        self.src_mask = None
+        self.src_mask: Optional[Tensor] = None
 
     def generate_square_subsequent_mask(self, size: int) -> Tensor:
         """Generate a square mask for the sequence to prevent future tokens from being seen."""
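
Not part of the patch itself: a minimal standalone sketch of what the added `generate_square_subsequent_mask` method produces, copied out of the diff above so its output can be inspected in isolation (the wrapping script and the printed example are illustrative, not from the patch).

```python
import torch
from torch import Tensor


def generate_square_subsequent_mask(size: int) -> Tensor:
    """Standalone copy of the method introduced by the patch (illustrative only)."""
    # Ones strictly above the diagonal mark future positions; they become -inf,
    # everything on and below the diagonal stays 0.0 (an additive causal mask).
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
    return mask


if __name__ == "__main__":
    print(generate_square_subsequent_mask(4))
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])
```

Adding -inf to attention scores at future positions zeroes their weight after softmax, whereas the old boolean-tril-then-fill code built an equivalent mask inline on every forward pass; the patch caches it in `self.src_mask` and regenerates only when the sequence length changes.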