Hi,
I have a quick question about the relative shift operation:

    def _rel_shift(self, x, zero_triu=False):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x

In the Transformer-XL paper, Appendix B (https://arxiv.org/pdf/1901.02860.pdf), the upper-right triangle of matrix B consists of zeros. In the code above, and throughout the model implementation, zero_triu == False, so after the relative shift the upper-right triangle is not filled with zeros as described in the paper.
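To make the question concrete, here is a minimal pure-Python sketch of the same pad, reshape, and slice trick on a single 2-D score matrix. It is only an illustration under stated assumptions: the helper names `rel_shift_2d` and `zero_upper_triangle` are hypothetical, plain lists stand in for tensors, and the batch and head dimensions are dropped.

```python
def rel_shift_2d(x):
    """List-based analogue of _rel_shift for one (qlen x klen) matrix (hypothetical helper)."""
    qlen, klen = len(x), len(x[0])
    # Pad one zero column on the left and flatten row by row: qlen * (klen + 1) values.
    flat = [v for row in x for v in [0] + row]
    # Viewing that buffer as (klen + 1) x qlen and dropping its first row
    # is the same as discarding the first qlen values of the flat buffer.
    flat = flat[qlen:]
    # Reinterpret the remaining qlen * klen values as qlen x klen.
    return [flat[i * klen:(i + 1) * klen] for i in range(qlen)]


def zero_upper_triangle(x):
    """Analogue of zero_triu=True: keep only entries with j - i <= klen - qlen."""
    qlen, klen = len(x), len(x[0])
    return [[v if j - i <= klen - qlen else 0 for j, v in enumerate(row)]
            for i, row in enumerate(x)]


scores = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
shifted = rel_shift_2d(scores)
print(shifted)                       # [[3, 0, 4], [5, 6, 0], [7, 8, 9]]
print(zero_upper_triangle(shifted))  # [[3, 0, 0], [5, 6, 0], [7, 8, 9]]
```

In this toy case the entry at row 0, column 2 of the shifted matrix is 4, a value carried over from the next row rather than a zero. The positions that `zero_triu` would clear are exactly those with j > i + (klen - qlen), which is also the set a causal attention mask of the same shape would cover, assuming such a mask is applied to the scores afterward.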
In the Hugging Face implementation of this function, this unused parameter has been removed entirely (see https://github.com/huggingface/transformers/blob/master/src/transformers/models/transfo_xl/modeling_transfo_xl.py#L275).
Is the upper-right triangle masked at a later stage regardless, or is there another reason why zero_triu can be ignored?