Hey David!

This is quite curious! For starters, it's odd that XLA decided the memory-optimised layout for that tensor should put the dimension of size 6 at index 0, causing it to be padded out to 128 to match the TPU MXU tile size. To help us look into that, is there any chance you could send through the unoptimised HLO? You can get it by running `jit(f).lower(*args).as_text()`.
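For reference, here's a minimal sketch of that call; `f`, `x`, and the shapes are placeholders rather than your actual model and inputs:

```python
import jax
import jax.numpy as jnp

# Stand-in function and input; substitute your real model and arguments.
def f(x):
    return jnp.sin(x) @ x.T

x = jnp.ones((6, 512))  # hypothetical tensor with the size-6 leading dim

# Lower before XLA's optimisation passes run and dump the program as text
# (depending on your JAX version this prints StableHLO or HLO).
print(jax.jit(f).lower(x).as_text())
```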

I've quickly put together a minimal reproduction, which I've posted below. It behaves sensibly when I don't vmap per example inside the microbatch, but when I do, it induces a huge all-to-all before the gradients are computed. I'm still looking into exactly what's going on, but it looks quite similar to your issue. Can you check through quickly and see if…
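(The original reproduction snippet isn't preserved in this extract. The sketch below only illustrates the pattern being described, a per-example vmap inside a sharded microbatch, using a toy loss and made-up shapes; it is not the actual repro.)

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy per-example loss; the real model from the discussion is not shown here.
def per_example_loss(params, example):
    return jnp.sum((example @ params) ** 2)

def microbatch_loss(params, microbatch):
    # vmap per example inside the microbatch -- the pattern suspected of
    # interacting badly with the sharded layout and inducing the all-to-all.
    losses = jax.vmap(per_example_loss, in_axes=(None, 0))(params, microbatch)
    return losses.mean()

# Shard the microbatch across a 1D "data" mesh; sizes are made up.
mesh = Mesh(jax.devices(), ("data",))
params = jnp.ones((512, 512))
microbatch = jax.device_put(
    jnp.ones((8 * len(jax.devices()), 512)),
    NamedSharding(mesh, P("data", None)),
)

grads = jax.jit(jax.grad(microbatch_loss))(params, microbatch)

# The compiled HLO can then be scanned for unexpected all-to-all ops.
print(jax.jit(jax.grad(microbatch_loss))
      .lower(params, microbatch).compile().as_text())
```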

Answer selected by dlwh