Relative Positional Encoding #138

@LarsHill

Hi,

I have a quick question about the relative shift operation:

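    def _rel_shift(self, x, zero_triu=False):
        # x: (qlen, klen, bsz, n_head)
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)  # (qlen, klen + 1, bsz, n_head)

        # same memory, leading dims reinterpreted as (klen + 1, qlen)
        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        # dropping the first row and viewing back as (qlen, klen, ...)
        # shifts query row i to the left by qlen - 1 - i
        x = x_padded[1:].view_as(x)

        if zero_triu:
            # keep only j <= i + (klen - qlen); build the mask on x's device
            ones = torch.ones((x.size(0), x.size(1)), device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x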

In the Transformer-XL paper, Appendix B (https://arxiv.org/pdf/1901.02860.pdf), the upper-right triangle of matrix B consists of zeros. In the code above, and everywhere in the model implementation, zero_triu is left at False, so after the relative shift the upper-right triangle is not filled with zeros as the paper describes.
In the Hugging Face implementation of this function, the unused parameter has been removed entirely (see https://github.com/huggingface/transformers/blob/master/src/transformers/models/transfo_xl/modeling_transfo_xl.py#L275).
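
To see concretely what ends up in the upper-right triangle, here is a minimal sketch that runs a free-function copy of `_rel_shift` on a toy 3×3 score matrix with `bsz = n_head = 1`. The name `rel_shift` is just for this sketch; `self` is dropped and the mask is created on `x`'s device. For entries with `j <= i + mlen` (where `mlen = klen - qlen`) the shift gives `out[i, j] = x[i, j + qlen - 1 - i]`; above that diagonal it leaves one padding zero followed by values wrapped in from the next query row, and those are what `zero_triu=True` clears.

```python
import torch


def rel_shift(x: torch.Tensor, zero_triu: bool = False) -> torch.Tensor:
    """Free-function copy of _rel_shift; x has shape (qlen, klen, bsz, n_head)."""
    qlen, klen = x.size(0), x.size(1)
    zero_pad = torch.zeros((qlen, 1, *x.size()[2:]), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)  # (qlen, klen + 1, bsz, n_head)
    # Same memory, leading dims reinterpreted as (klen + 1, qlen).
    x_padded = x_padded.view(klen + 1, qlen, *x.size()[2:])
    # Dropping the first row and viewing back shifts query row i left by qlen - 1 - i.
    out = x_padded[1:].view_as(x)
    if zero_triu:
        # Keep only j <= i + (klen - qlen); the upper-right triangle becomes 0.
        ones = torch.ones((qlen, klen), device=x.device, dtype=x.dtype)
        out = out * torch.tril(ones, klen - qlen)[:, :, None, None]
    return out


# Toy scores: qlen = klen = 3 (no memory), bsz = n_head = 1.
x = torch.arange(1.0, 10.0).view(3, 3, 1, 1)
print(rel_shift(x)[:, :, 0, 0].tolist())
# [[3.0, 0.0, 4.0], [5.0, 6.0, 0.0], [7.0, 8.0, 9.0]]
print(rel_shift(x, zero_triu=True)[:, :, 0, 0].tolist())
# [[3.0, 0.0, 0.0], [5.0, 6.0, 0.0], [7.0, 8.0, 9.0]]
```

Without `zero_triu`, the 4 at position (0, 2) is a leftover from query row 1; with it, the upper-right triangle is all zeros, as in the paper's matrix B.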

Is the upper-right triangle masked somewhere later regardless, or why else can zero_triu be ignored?
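
A quick way to test the "masked later" explanation: if the scores go through a causal mask (keys with `j > i + mlen` set to `-inf`) before the softmax, then whatever the shift leaves in the upper-right triangle has no effect on the attention weights. This sketch assumes such a mask is applied before the softmax; `causal_mask` and `masked_softmax` are hypothetical helpers, not functions from either repository, and the sketch does not show where the model actually applies its mask.

```python
import torch


def causal_mask(qlen: int, klen: int) -> torch.Tensor:
    """Hypothetical helper: True where key j lies in the future of query i.

    With mlen = klen - qlen memory positions, query i may attend to keys 0..i + mlen.
    """
    mlen = klen - qlen
    return torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen).bool()


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: set masked scores to -inf, then normalise over keys."""
    return torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)


qlen, klen = 4, 6
mask = causal_mask(qlen, klen)

torch.manual_seed(0)
scores = torch.randn(qlen, klen)
# Same scores, but with large arbitrary values written into the masked region.
junk = scores.clone()
junk[mask] = 1e4 * torch.randn(int(mask.sum()))

print(torch.equal(masked_softmax(scores, mask), masked_softmax(junk, mask)))
# True
```

Because every query row keeps at least one unmasked key, the softmax never sees a row of only `-inf`, so the comparison is exact rather than approximate.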
