How can I understand T(8,128)(2,1) and braces in the error information of padding on TPU? #19344
Unanswered · mathpluscode asked this question in Q&A
Hi, I am using JAX with convolutional layers on TPU, and I got the following error message. Here the shape `128,128,12` is the shape of the image and `32` is the batch size. I do understand that on TPU we need some padding, but how can we understand `bf16[32,128,128,12,64]{0,4,3,2,1:T(8,128)(2,1)}`? In particular, what do `{0,4,3,2,1}` and `T(8,128)(2,1)` mean here?

🤔 I found https://github.com/openxla/xla/blob/main/docs/tiled_layout.md#repeated-tiling: `{0,4,3,2,1}` relates to the physical order of dimensions, so `[32,128,128,12,64]` is actually stored in shape `[32,64,12,128,128]`. 🤔 I wonder whether the tiling of `(8,128)` corresponds to the axes `[32,64]` or `[64,32]`? If the 128 corresponds to the 32, that represents the 4x padding, which would also explain why the expansion is 4x for both `[32,128,128,12,64]` and `[32,128,128,12,16]`. `T(8,128)(2,1)` means a repeated tiling: first tiled by `(8,128)`, then tiled inside by `(2,1)`; the second tiling should not require padding.

Update: I modified my description with some guesses from the documents, but I would still be happy to get a confirmation! Super thanks for the help in advance!
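If I read the tiled-layout doc correctly, the padded physical shape can be sketched in a few lines of Python. This is my own helper, not an XLA API, and it assumes the `(8,128)` tile covers the two minormost physical dimensions (which for `{0,4,3,2,1}` are the `64` and `32` of the logical shape):

```python
import math

def padded_physical_shape(shape, minor_to_major, tile=(8, 128)):
    """Sketch: physical shape after applying a tile, major-to-minor order.

    `minor_to_major` lists logical dimensions from minormost to majormost,
    as in an XLA layout like {0,4,3,2,1}. The tile pads the minormost
    len(tile) physical dimensions up to multiples of the tile sizes.
    """
    # Physical major-to-minor order is the reverse of minor_to_major.
    phys = [shape[d] for d in reversed(minor_to_major)]
    padded = list(phys)
    for i, t in enumerate(tile):
        j = len(phys) - len(tile) + i  # tile aligns with the minormost dims
        padded[j] = math.ceil(padded[j] / t) * t
    return padded

# bf16[32,128,128,12,64]{0,4,3,2,1}: physical order is [128,128,12,64,32],
# so the (8,128) tile hits 64 (-> 64) and 32 (-> 128): a 4x expansion.
a = padded_physical_shape([32, 128, 128, 12, 64], [0, 4, 3, 2, 1])
print(a)  # -> [128, 128, 12, 64, 128]

# Same story for [32,128,128,12,16]: only the batch dim 32 pads to 128.
b = padded_physical_shape([32, 128, 128, 12, 16], [0, 4, 3, 2, 1])
print(b)  # -> [128, 128, 12, 16, 128]
```

Under this reading, the 4x expansion for both shapes comes entirely from the minormost physical dimension (the batch, 32) being padded to 128, and the `(2,1)` sub-tile (two bf16 values packed per 32-bit word, per my understanding of the doc) adds no further padding because the tiled dimensions are already even.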