diff --git a/transformer.py b/transformer.py index 12e674a..51cf108 100644 --- a/transformer.py +++ b/transformer.py @@ -68,7 +68,7 @@ def forward(self, x): # mask = mask.round().to(dtype=torch.int64) # masked_indices = torch.zeros_like(z_indices) masked_indices = self.mask_token_id * torch.ones_like(z_indices, device=z_indices.device) - a_indices = mask * z_indices + (~mask) * masked_indices + a_indices = ~mask * z_indices + (mask) * masked_indices a_indices = torch.cat((sos_tokens, a_indices), dim=1) @@ -76,7 +76,7 @@ def forward(self, x): logits = self.transformer(a_indices) - return logits, target + return logits[~mask], target[~mask] def top_k_logits(self, logits, k): v, ix = torch.topk(logits, k)