
cannot match attention layer output to pytorch's one #231

@CarloLucibello

Description

I'm struggling to import the weights from torchvision's ViT into ours. The problem is that the correct mapping from the attention layers in torch to the ones in Metalhead seems non-trivial.

using PythonCall, Metalhead
torch = pyimport("torch")

# Convert a PyTorch tensor to a Julia array, reversing the dimension
# order (PyTorch is row-major, Julia is column-major).
function th2jl(x::Py)
    xj = pyconvert(Array, x.detach().numpy())
    xj = permutedims(xj, ndims(xj):-1:1)
    return xj
end

m = torch.nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=true, bias=false, add_bias_kv=false)
# python forward pass
x = torch.randn(1, 3, 2)
y, a = m(x, x, x, need_weights=true)

mj = Metalhead.MHAttention(2, 1, qkv_bias=false)
# copy weights
mj.qkv_layer.weight .= th2jl(m.in_proj_weight)'           # transpose back: th2jl reverses the dims, but PyTorch Linear and Flux Dense both store weights as (out, in)
mj.projection.layers[1].weight .= th2jl(m.out_proj.weight)'
# julia forward pass
xj = th2jl(x)
yj = mj(xj)

@assert yj ≈ th2jl(y)  # false
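One way to narrow down where the mismatch comes from is to compute single-head attention by hand from the converted PyTorch weights and compare it against both outputs. A minimal sketch (assuming the usual [W_q; W_k; W_v] row layout of torch's in_proj_weight, and using the softmax from Flux; NNlib's would work too):

using Flux: softmax

Wqkv = permutedims(th2jl(m.in_proj_weight))   # back to PyTorch's (3*embed_dim, embed_dim) layout
Wq, Wk, Wv = Wqkv[1:2, :], Wqkv[3:4, :], Wqkv[5:6, :]
Wo = permutedims(th2jl(m.out_proj.weight))    # output projection, (embed_dim, embed_dim)

X = xj[:, :, 1]                               # (embed_dim, seq_len) for the single batch element
Q, K, V = Wq * X, Wk * X, Wv * X              # columns are per-token queries/keys/values
scores = (Q' * K) ./ sqrt(size(Q, 1))         # scores[i, j] = qᵢ ⋅ kⱼ / √head_dim (one head here)
A = softmax(scores; dims = 2)                 # attention weights over the keys
yref = Wo * (V * A')                          # weighted values, then output projection

yref ≈ th2jl(y)[:, :, 1]   # expected to hold (plain attention math, 1 head, no bias, no dropout)
yref ≈ yj[:, :, 1]         # compare against Metalhead's output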

This is probably due to the permutations and chunking in our initial projection; possibly we should rearrange them so that the natural weight mapping from PyTorch just works.
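If the difference were only a row ordering inside the stacked q/k/v projection, the fix on the import side would be a row permutation when copying in_proj_weight. A hypothetical sketch: it assumes MHAttention wants the rows grouped head-major ([q; k; v] per head) rather than PyTorch's chunk-major order (all q heads, then all k heads, then all v heads); whether that, or any, permutation is actually needed depends on the dimension order used in MHAttention's internal reshape.

# Hypothetical: reorder PyTorch's chunk-major rows
# [q_head1 … q_headH; k_head1 … k_headH; v_head1 … v_headH]
# into a head-major layout [q_head1; k_head1; v_head1; q_head2; …].
function qkv_row_permutation(embed_dim, nheads)
    head_dim = embed_dim ÷ nheads
    perm = Int[]
    for h in 1:nheads, c in 1:3              # heads outermost, then q/k/v
        offset = (c - 1) * embed_dim + (h - 1) * head_dim
        append!(perm, offset .+ (1:head_dim))
    end
    return perm
end

Wqkv = permutedims(th2jl(m.in_proj_weight))  # PyTorch's (3*embed_dim, embed_dim) layout
mj.qkv_layer.weight .= Wqkv[qkv_row_permutation(2, 1), :]

With a single head, as in the snippet above, any per-head permutation is the identity, so head ordering alone can't explain the failing assertion; that points back at the chunking/reshape in the forward pass itself.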

Pinging @theabhirath for more insights.
