1 file changed: +13 −4 lines

src/lightning/pytorch/demos

@@ -54,15 +54,24 @@ def __init__(
 
         self.ninp = ninp
         self.vocab_size = vocab_size
-        self.src_mask = None
+        self.src_mask: Optional[Tensor] = None
+
+    def generate_square_subsequent_mask(self, size: int) -> Tensor:
+        """Generate a square mask for the sequence to prevent future tokens from being seen."""
+        mask = torch.triu(torch.ones(size, size), diagonal=1)
+        mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
+        return mask
 
     def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         _, t = inputs.shape
 
-        # we assume target is already shifted w.r.t. inputs
+        # Generate source mask to prevent future token leakage
+        if self.src_mask is None or self.src_mask.size(0) != t:
+            self.src_mask = self.generate_square_subsequent_mask(t).to(inputs.device)
+
+        # Generate target mask if not provided
         if mask is None:
-            mask = torch.tril(torch.ones(t, t, device=inputs.device)) == 1
-            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
+            mask = self.generate_square_subsequent_mask(t).to(inputs.device)
 
         src = self.pos_encoder(self.embedding(inputs) * math.sqrt(self.ninp))
         target = self.pos_encoder(self.embedding(target) * math.sqrt(self.ninp))