I think that in `mask_logits`, `return inputs + (-1e30) * (1 - mask)` should be replaced by `return inputs * mask + (-1e30) * (1 - mask)`, so that masked positions are set to exactly -1e30 instead of `inputs - 1e30`.
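Here is a minimal pure-Python sketch comparing the two forms. It stands in for the TensorFlow version and assumes `mask` holds 0.0/1.0 floats with the same shape as `inputs`. The names `mask_logits_original`, `mask_logits_proposed` and `softmax` are hypothetical helpers for illustration only. For ordinary logit magnitudes both forms round to the same masked value. The difference appears only when an input is comparable in size to 1e30.

```python
import math

NEG_INF = -1e30  # large negative constant, as in the current mask_logits


def mask_logits_original(inputs, mask):
    # Current form (hypothetical list version): masked positions become inputs[i] - 1e30
    return [x + NEG_INF * (1 - m) for x, m in zip(inputs, mask)]


def mask_logits_proposed(inputs, mask):
    # Proposed form (hypothetical list version): zero out masked inputs first,
    # so masked positions are exactly -1e30 regardless of the input value
    return [x * m + NEG_INF * (1 - m) for x, m in zip(inputs, mask)]


def softmax(logits):
    # Numerically stable softmax: subtract the max before exponentiating
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Typical logits: both forms give the masked position zero probability
print(softmax(mask_logits_original([2.0, 1.0, 3.0], [1.0, 1.0, 0.0]))[2])  # 0.0
print(softmax(mask_logits_proposed([2.0, 1.0, 3.0], [1.0, 1.0, 0.0]))[2])  # 0.0

# A masked input of 1e31 survives the current form but not the proposed one
print(softmax(mask_logits_original([1e31, 1.0], [0.0, 1.0])))  # [1.0, 0.0]
print(softmax(mask_logits_proposed([1e31, 1.0], [0.0, 1.0])))  # [0.0, 1.0]
```

With the proposed form, the masked value no longer depends on `inputs`. Masked positions therefore stay masked even if a logit is very large.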