
Hi, thanks for the question! It seems to rest on a misunderstanding of what data.shape means inside a JIT-compiled computation over sharded data. data.shape is always the logical shape of the entire (global) array, regardless of how that array is laid out across devices. If you want to see how an array is actually sharded from within a JIT-compiled function, you can use

jax.debug.inspect_array_sharding(data, callback=print)
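Here is a minimal sketch of that contrast. It assumes a recent JAX version and two CPU devices simulated with XLA_FLAGS. The mesh axis name "x" and the 4x2 example array are illustrative assumptions, not taken from the original question. It records what data.shape reports while jit traces the function, and what jax.debug.inspect_array_sharding reports for the same value:

```python
import os

# Assumption: simulate two CPU devices; this must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over two CPU devices, with an illustrative axis name "x".
mesh = Mesh(np.array(jax.devices("cpu")[:2]), axis_names=("x",))
sharding = NamedSharding(mesh, P("x", None))  # rows split over "x", columns not split

data = jax.device_put(np.arange(8, dtype=np.float32).reshape(4, 2), sharding)

seen_shapes = []     # shapes observed while jit traces the function
seen_shardings = []  # shardings reported by inspect_array_sharding

@jax.jit
def f(data):
    # Inside jit, data.shape is the logical (global) shape of the whole array,
    # even though each device only holds half of the rows.
    seen_shapes.append(data.shape)
    # Reports the sharding the compiled computation uses for `data`.
    jax.debug.inspect_array_sharding(data, callback=seen_shardings.append)
    return data * 2

out = f(data)
print(seen_shapes)  # [(4, 2)]
print(seen_shardings)
```

data.shape is read during tracing, so it records the global shape (4, 2). The per-device block shape comes from the sharding itself, for example seen_shardings[0].shard_shape((4, 2)).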

When I add that to your function, I see this:

PositionalSharding([[{CPU 0}]
                    [{CPU 1}]])

This indicates that the array is indeed sharded as you would expect: it is split across the two CPU devices, even though data.shape reports the full logical shape.
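The same layout can also be checked eagerly, outside of jit, by looking at the array's addressable shards. This is a minimal sketch. It assumes two CPU devices simulated with XLA_FLAGS, a hypothetical mesh axis named "x", and an illustrative 4x2 array split along its first axis:

```python
import os

# Assumption: simulate two CPU devices; this must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over two CPU devices, with an illustrative axis name "x".
mesh = Mesh(np.array(jax.devices("cpu")[:2]), axis_names=("x",))
data = jax.device_put(
    np.arange(8, dtype=np.float32).reshape(4, 2),
    NamedSharding(mesh, P("x", None)),  # rows split over "x", columns not split
)

print(data.shape)  # (4, 2): the logical, global shape

# Each addressable shard is the block of the array stored on one device.
shard_shapes = [shard.data.shape for shard in data.addressable_shards]
for shard in data.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)
```

Each device holds a (2, 2) block, while data.shape still reports (4, 2). This matches what inspect_array_sharding shows inside jit.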

Answer selected by raresdolga