I'm struggling to import the weights from torchvision's ViT into ours. The problem is that the correct mapping of the attention layers in torch to the one in Metalhead seems non-trivial.
using PythonCall, Metalhead
torch = pyimport("torch")
# convert a torch tensor to a Julia array, reversing the dims (row-major -> column-major)
function th2jl(x::Py)
    xj = pyconvert(Array, x.detach().numpy())
    xj = permutedims(xj, ndims(xj):-1:1)
    return xj
end
m = torch.nn.MultiheadAttention(embed_dim=2, num_heads=1, batch_first=true, bias=false, add_bias_kv=false)
# python forward pass
x = torch.randn(1, 3, 2)
y, a = m(x, x, x, need_weights=true)
mj = Metalhead.MHAttention(2, 1, qkv_bias=false)
# copy weights
mj.qkv_layer.weight .= th2jl(m.in_proj_weight)' # transpose undoes the dim reversal; PyTorch Linear and Flux Dense both store weights as (out, in)
mj.projection.layers[1].weight .= th2jl(m.out_proj.weight)'
# julia forward pass
xj = th2jl(x)
yj = mj(xj)
@assert yj ≈ th2jl(y) # fails: the outputs do not match
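One way to narrow this down is to compare just the fused qkv projection on both sides; if that matches, the divergence has to come from how the fused output is split into heads and q/k/v afterwards. Something like:

# sanity check: apply only the fused qkv projection on both sides
qkv_jl = mj.qkv_layer(reshape(xj, 2, :))                         # (6, seq_len * batch)
qkv_th = th2jl(torch.nn.functional.linear(x, m.in_proj_weight))  # (6, seq_len, batch) after dim reversal
@assert qkv_jl ≈ reshape(qkv_th, 6, :) # should pass if the copied weights themselves are fine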
Probably this is due to the permutations and chunking in our initial projection; we could perhaps rearrange them so that the natural weight mapping from PyTorch just works (a rough sketch follows).
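Here is a rough sketch of that rearrangement idea. It assumes (and this is only an assumption, it needs checking against MHAttention's forward pass) that the Julia side effectively reshapes the fused qkv output as (head_dim, 3, nheads, :), while PyTorch stacks in_proj_weight as [W_q; W_k; W_v] with the heads contiguous inside each block; the qkv_row_permutation helper below is hypothetical:

# hypothetical helper: reorder the rows of PyTorch's in_proj_weight from
# (q/k/v slowest, head, within-head dim fastest) to (head slowest, q/k/v, dim fastest),
# which is what a reshape of the fused output into (head_dim, 3, nheads, :) would expect
function qkv_row_permutation(embed_dim, nheads)
    head_dim = embed_dim ÷ nheads
    return [(qkv - 1) * embed_dim + (h - 1) * head_dim + d
            for h in 1:nheads for qkv in 1:3 for d in 1:head_dim]
end

Wqkv = th2jl(m.in_proj_weight)'   # (3 * embed_dim, embed_dim), same layout as in PyTorch
perm = qkv_row_permutation(2, 1)  # identity here, since the toy example uses a single head
mj.qkv_layer.weight .= Wqkv[perm, :]

With a single head the permutation is the identity, so this alone doesn't explain the failing single-head example above; the way q, k and v are chunked out of the fused output may also differ.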
Pinging @theabhirath for more insights.