diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index fefa073fbd310..13b5e05adc680 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -54,15 +54,24 @@ def __init__(
         self.ninp = ninp
         self.vocab_size = vocab_size
-        self.src_mask = None
+        self.src_mask: Optional[Tensor] = None
+
+    def generate_square_subsequent_mask(self, size: int) -> Tensor:
+        """Generate a square mask for the sequence to prevent future tokens from being seen."""
+        mask = torch.triu(torch.ones(size, size), diagonal=1)
+        mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
+        return mask
 
     def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         _, t = inputs.shape
 
-        # we assume target is already shifted w.r.t. inputs
+        # Generate source mask to prevent future token leakage
+        if self.src_mask is None or self.src_mask.size(0) != t:
+            self.src_mask = self.generate_square_subsequent_mask(t).to(inputs.device)
+
+        # Generate target mask if not provided
         if mask is None:
-            mask = torch.tril(torch.ones(t, t, device=inputs.device)) == 1
-            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
+            mask = self.generate_square_subsequent_mask(t).to(inputs.device)
 
         src = self.pos_encoder(self.embedding(inputs) * math.sqrt(self.ninp))
         target = self.pos_encoder(self.embedding(target) * math.sqrt(self.ninp))
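
For context, here is a minimal standalone sketch (not part of the diff) checking that the new `torch.triu`-based helper produces the same additive attention mask as the `torch.tril`-based construction it replaces. The free function below simply mirrors the `generate_square_subsequent_mask` method added in the diff; the comparison script itself is illustrative only.

```python
import torch


def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    """Additive mask: 0.0 on and below the diagonal, -inf strictly above it."""
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)


t = 4
new_mask = generate_square_subsequent_mask(t)

# Previous construction from the demo, reproduced here for comparison.
old = torch.tril(torch.ones(t, t)) == 1
old_mask = old.float().masked_fill(old == 0, float("-inf")).masked_fill(old == 1, 0.0)

assert torch.equal(new_mask, old_mask)
print(new_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```

The -inf entries are added to the attention logits before softmax, so each position can only attend to itself and earlier positions; caching the result in `self.src_mask` avoids rebuilding it on every forward pass with the same sequence length.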