jit parallel computation #19136
Hi, if I understand correctly, when you shard data across two devices, the computation should happen in parallel on each device. But when I look at the shape of the data, it does not seem to be split. Reproducing code:
Output:
Desired Output:
Although I can see that the input data is sharded along the first dimension, with half of the elements on the first device, all the elements seem to be passed when the function is called, instead of being split. If I use "shard_map", I get the desired behaviour.
Replies: 1 comment 2 replies
Hi - thanks for the question! This question seems predicated on an incorrect understanding of what `data.shape` means within a `jit`-compiled computation over sharded data. `data.shape` is the logical shape of the entire array, regardless of how it is laid out across devices. If you want to inspect the sharding of an array at runtime within a JIT-compiled function, you can use `jax.debug.inspect_array_sharding(data, callback=print)`. When I add that to your function, I see this:
This indicates that the array is indeed sharded as you would expect.
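For readers following along, here is a minimal sketch of the distinction described above. It is not the original poster's code: the mesh axis name `"x"`, the function `f`, the array size, and the `seen_shapes` list are all hypothetical and chosen only for illustration. It assumes a JAX version that provides `jax.sharding.NamedSharding`, and it uses however many devices are available, so it should also run on a single CPU device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over all available devices (hypothetical axis name "x").
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Hypothetical input: 8 rows per device, so the first axis divides evenly.
n_rows = 8 * devices.size
data = jax.device_put(
    jnp.arange(n_rows * 4, dtype=jnp.float32).reshape(n_rows, 4),
    NamedSharding(mesh, P("x", None)),  # shard the first axis over "x"
)

seen_shapes = []  # records the shape observed while tracing f


@jax.jit
def f(x):
    # This line runs at trace time; x.shape is the logical (global) shape.
    seen_shapes.append(x.shape)
    # Reports the sharding that the compiled computation uses for x.
    jax.debug.inspect_array_sharding(x, callback=print)
    return x * 2


out = f(data)

# Outside jit, the per-device pieces are visible through addressable_shards.
shard_shapes = [shard.data.shape for shard in data.addressable_shards]
print("shape seen inside jit:", seen_shapes[0])
print("per-device shard shapes:", shard_shapes)
```

In this sketch the shape seen inside `f` is the full `(n_rows, 4)` array, while each entry of `shard_shapes` is the `(8, 4)` slice held by one device. That matches the explanation above: `jit` presents a global view of the array and leaves the partitioning to the compiler.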