TypeError: dot_general requires contracting dimensions to have the same shape #10947
Answered by YouJiacheng
conceptofmind asked this question in Q&A
Hi, I have recently been adapting some code from PyTorch to Jax. The code runs without conflict and executes properly in PyTorch but I am having difficulty properly implementing it in Jax. Any input would be greatly appreciated. PyTorch:
Jax:
Thank you, Enrico
Answered by YouJiacheng on Jun 2, 2022
Answer selected by conceptofmind
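Hi. Your matrix multiplication shape is (dim, patch_dim) @ (patch_num, patch_dim), so the contracting dimensions (patch_dim and patch_num) do not match. Use this instead, which puts patch_dim on the second-to-last axis:
new_img = rearrange(img, 'b c (h p1) (w p2) -> b (p1 p2 c) (h w)', p1 = patch_height, p2 = patch_width)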
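To make the shapes concrete, here is a minimal sketch that checks the suggested layout. The original code from the question is not shown here, so everything in it is an assumption for illustration: the image size, patch size, dim, and the weight array are made up, and the helper to_patches_channels_first is a hypothetical plain jax.numpy stand-in for the einops pattern above, not the asker's actual code.

```python
import jax
import jax.numpy as jnp


def to_patches_channels_first(img, patch_height, patch_width):
    """Plain-jnp stand-in for einops 'b c (h p1) (w p2) -> b (p1 p2 c) (h w)'."""
    b, c, height, width = img.shape
    h, w = height // patch_height, width // patch_width
    # Split each spatial axis into (number of patches, patch size): b c h p1 w p2
    x = img.reshape(b, c, h, patch_height, w, patch_width)
    # Reorder the axes to b p1 p2 c h w
    x = x.transpose(0, 3, 5, 1, 2, 4)
    # Merge into (b, patch_dim, patch_num)
    return x.reshape(b, patch_height * patch_width * c, h * w)


# Assumed example sizes, chosen so that patch_dim != patch_num.
key_img, key_w = jax.random.split(jax.random.PRNGKey(0))
img = jax.random.normal(key_img, (2, 3, 8, 8))  # (b, c, H, W)
patch_height = patch_width = 4
dim = 16

new_img = to_patches_channels_first(img, patch_height, patch_width)
patch_dim, patch_num = new_img.shape[1], new_img.shape[2]

# Hypothetical projection weight of shape (dim, patch_dim).
weight = jax.random.normal(key_w, (dim, patch_dim))

# (dim, patch_dim) @ (b, patch_dim, patch_num) -> (b, dim, patch_num)
out = weight @ new_img
print(new_img.shape, out.shape)
```

With the patch axis last, weight's trailing patch_dim axis lines up with new_img's second-to-last axis, which is the pair of axes the matrix product contracts. In the (b, patch_num, patch_dim) layout the second-to-last axis is patch_num instead, and that mismatch is what the dot_general error reports.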