
Bug Report for self-attention #5219

@OmPatel512

Description


Bug Report for https://neetcode.io/problems/self-attention

I have implemented self-attention as shown below, but the numbers in my output are different from the expected output. The shape and everything else is as expected.

import torch
import torch.nn as nn
from torchtyping import TensorType

class SingleHeadAttention(nn.Module):

    def __init__(self, embedding_dim: int, attention_dim: int):
        super().__init__()
        torch.manual_seed(0)
        self.Q = nn.Linear(embedding_dim, attention_dim)
        self.K = nn.Linear(embedding_dim, attention_dim)
        self.V = nn.Linear(embedding_dim, attention_dim)
        # sqrt(d_k), used to scale the attention scores
        self.d_k = torch.sqrt(torch.tensor(attention_dim, dtype=torch.float32))

    def forward(self, embedded: TensorType[float]) -> TensorType[float]:
        # Return your answer to 4 decimal places
        Q = self.Q(embedded)
        K = self.K(embedded)
        V = self.V(embedded)

        # Scaled dot-product scores: Q @ K^T / sqrt(d_k)
        score = Q.matmul(torch.transpose(K, 1, 2))
        score = score / self.d_k

        # Causal mask: block attention to future positions
        L = score.size(-1)
        mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
        score = score.masked_fill(mask, float('-inf'))
        energy = nn.functional.softmax(score, dim=-1)

        return energy.matmul(V)
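A minimal sketch for reproducing this locally (the batch size, sequence length, and dimensions below are assumptions chosen to match the output shape above; the grader's actual test input is not shown here):

    import torch

    # Illustrative sizes only: output above is (2, 2, 4), so attention_dim = 4
    batch_size, seq_len, embedding_dim, attention_dim = 2, 2, 3, 4

    torch.manual_seed(0)
    embedded = torch.randn(batch_size, seq_len, embedding_dim)

    attn = SingleHeadAttention(embedding_dim, attention_dim)
    out = attn(embedded)

    print(out.shape)                       # torch.Size([2, 2, 4])
    print(torch.round(out, decimals=4))    # rounded to 4 decimal places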

----------- OUTPUT ----------------

Your Output:

[[[-0.0370, -2.1133, 0.6130, 1.0282], [0.0831, -1.5672, 0.4055, 0.9310]],
 [[-0.5007, -1.1551, 0.6808, 0.6887], [-0.2984, -0.2898, 0.3125, 0.4436]]]

Expected Output:

[[[-1.3004, -0.4002, 0.3222, 0.8869], [-1.1122, -0.1099, 0.1266, 0.9520]],
 [[-0.3623, -0.4155, 0.5964, -0.0563], [-0.1706, -0.0719, 0.2320, 0.0715]]]
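Since the problem asks for the answer to 4 decimal places, a quick sanity check is whether this is a real numerical difference rather than a rounding artifact. A small sketch using the two tensors pasted above:

    import torch

    mine = torch.tensor([[[-0.0370, -2.1133, 0.6130, 1.0282],
                          [0.0831, -1.5672, 0.4055, 0.9310]],
                         [[-0.5007, -1.1551, 0.6808, 0.6887],
                          [-0.2984, -0.2898, 0.3125, 0.4436]]])
    expected = torch.tensor([[[-1.3004, -0.4002, 0.3222, 0.8869],
                              [-1.1122, -0.1099, 0.1266, 0.9520]],
                             [[-0.3623, -0.4155, 0.5964, -0.0563],
                              [-0.1706, -0.0719, 0.2320, 0.0715]]])

    # allclose with a loose tolerance rules out a pure rounding/precision issue
    print(torch.allclose(mine, expected, atol=1e-3))   # False: the values genuinely differ
    print((mine - expected).abs().max())               # largest element-wise gap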
