tl;dr: Is there any way to use `with_sharding_constraint` inside of a vmapped function to force partitioning of the batch axis?

I'm tracking down some resource use issues where I think that the partitioner is not sharding dropout masks across the batch axis. My code broadly looks like:

The reason I suspect this is what's going on is that I get OOM errors saying it's trying to allocate precisely the amount of RAM it would need if it were not partitioning the batch dimension for the dropout mask. [1] I can try to minimize and share actual code, but before I do, I wanted to ask:

The existing unit tests (https://github.com/google/jax/blob/480efcf0ee13e8c471c0b3e42a582028fcdccd3c/tests/pjit_test.py#L428 ) don't seem to test for anything like this, but maybe I'm just not understanding them.

[1] I stupidly lost the logs, but I get errors that it's trying to allocate 3.36GB, and the dropout mask in question is 128 * 25 * 1024 * 1024, which is 3.355 GB... could be a coincidence of course, but the OOM goes away when I remove dropout.
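The size estimate in footnote [1] can be checked with a few lines of Python. This is only a sketch of the arithmetic: it assumes the mask is stored at one byte per element (for example a boolean mask), which is what makes the 3.355 GB figure come out, and the device count used for the sharded case is a hypothetical value that does not come from the post.

```python
import math

# Dropout mask shape quoted in the footnote.
mask_shape = (128, 25, 1024, 1024)

# Assumption: one byte per element (e.g. a boolean mask).
bytes_per_element = 1

n_elements = math.prod(mask_shape)
total_bytes = n_elements * bytes_per_element

size_gb = total_bytes / 1e9      # decimal gigabytes
size_gib = total_bytes / 2**30   # binary gibibytes

print(n_elements)  # 3355443200
print(size_gb)     # 3.3554432
print(size_gib)    # 3.125

# Hypothetical: if the leading (batch) axis were split evenly over 8 devices,
# each device would hold only 1/8 of the mask.
hypothetical_devices = 8
per_device_bytes = total_bytes // hypothetical_devices
print(per_device_bytes / 1e9)
```

The unsharded total lines up with the reported 3.36GB allocation, which is consistent with (though not proof of) the mask being materialized without batch partitioning.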
I got @pschuh to help me think on this one.

We recently added an experimental (intentionally not-yet-documented) option to `vmap`, via the keyword argument `spmd_axis_name`, that might be useful here. See #11807. What do you think?

In general, you can't inspect pjit's sharding specs because they're applied only downstream, at compilation time.

We're fairly certain that xmap will override sharding specs, so it's unclear whether it would work here. Maybe @apaszke can confirm or correct.

One way to avoid this question in today's world is indeed to rewrite your model to be batch-polymorphic so that vmap isn't required. We're pretty sure that Flax does that for similar reasons. Maybe @levskaya or @jekbradbury can confirm or correct.

And just thinking out loud: this might suggest that we consider enhancing custom batching (#9073) to also involve axis names. But custom batching is a work in progress, and it's still on our queue to land even with its current scope. (fyi @mattjj)

Thanks, I should add! This is useful feedback.
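For reference, here is a minimal sketch of how the `spmd_axis_name` option mentioned above might be combined with a sharding constraint inside a vmapped function. It is not the original poster's code: the function `dropout_one`, the mesh axis name `"data"`, the array shapes, and the dropout rate are all hypothetical, and it uses the `NamedSharding` API from more recent JAX releases rather than the pjit-era spelling. Whether this actually resolves the OOM was not confirmed in this thread.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D device mesh with a single axis named "data".
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))


def dropout_one(key, x, rate=0.1):
    """Per-example dropout; x has shape (seq, features) with no batch axis."""
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    # The constraint is written for the per-example shape only. With
    # spmd_axis_name set on vmap, the batched version of this constraint
    # is asked to shard the new batch dimension over the "data" mesh axis.
    keep = jax.lax.with_sharding_constraint(
        keep, NamedSharding(mesh, P(None, None))
    )
    return jnp.where(keep, x / (1.0 - rate), 0.0)


# vmap adds the batch axis; spmd_axis_name ties that axis to mesh axis "data".
batched_dropout = jax.jit(
    jax.vmap(dropout_one, in_axes=(0, 0), spmd_axis_name="data")
)

# Keep the batch size a multiple of the device count so it divides evenly.
batch = 4 * devices.size
keys = jax.random.split(jax.random.PRNGKey(0), batch)
x = jnp.ones((batch, 8, 16))
y = batched_dropout(keys, x)
```

The per-example function never mentions the batch axis, so the same code works with or without `vmap`. The only place the batch-to-mesh mapping appears is the `spmd_axis_name` argument.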