Does pjit have partition rules? #11684
Hi, I'm trying to understand what

import jax
import numpy as np
from jax.config import config
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit

# [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
print(jax.devices())

# Arrange the two GPUs as a 2x1 mesh with axes named 'x' and 'y'.
devices = np.asarray(jax.devices()).reshape(2, 1)
mesh = maps.Mesh(devices, ('x', 'y'))

# array([[0, 1, 2, 3],
#        [4, 5, 6, 7]])
input_data = np.arange(8).reshape(2, 4)

def dummy_f(x):
    return x

# Shard the input rows over mesh axis 'x'; replicate the output.
f = pjit(
    dummy_f,
    in_axis_resources=PartitionSpec('x', None),
    out_axis_resources=None
)

with maps.Mesh(mesh.devices, ('x', 'y')):
    data = f(input_data)
    print(data, data.device_buffers)

What I expect is:

array([[0, 1, 2, 3],  =(in_axis)=>  device0: [[0, 1, 2, 3]]
       [4, 5, 6, 7]]) =(in_axis)=>  device1: [[4, 5, 6, 7]]
                      =(out_axis)=> device0: [[0, 1, 2, 3], [4, 5, 6, 7]]
                      =(out_axis)=> device1: [[0, 1, 2, 3], [4, 5, 6, 7]]

However, I receive an NCCL error:

I have a few concerns:
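
For reference, the layout expected in the question can be sketched in plain NumPy, without pjit or any GPU. This is only an illustration under stated assumptions: `PartitionSpec('x', None)` with a mesh axis `'x'` of size 2 is modelled as splitting the first array axis into two blocks, and `out_axis_resources=None` is modelled as giving every device a full copy. The helpers `shard_rows` and `replicate` are hypothetical names, not part of JAX, and the sketch does not reproduce or diagnose the NCCL error.

```python
import numpy as np


def shard_rows(array, num_devices):
    """Split `array` along axis 0 into one equal block per device (hypothetical helper)."""
    return np.split(array, num_devices, axis=0)


def replicate(shards, num_devices):
    """Reassemble the shards and hand each device its own full copy (hypothetical helper)."""
    full = np.concatenate(shards, axis=0)
    return [full.copy() for _ in range(num_devices)]


input_data = np.arange(8).reshape(2, 4)

# Models in_axis_resources=PartitionSpec('x', None) on a mesh with 'x' of size 2.
in_shards = shard_rows(input_data, 2)

# Models out_axis_resources=None, i.e. a fully replicated output.
out_copies = replicate(in_shards, 2)

for i, shard in enumerate(in_shards):
    print(f"in  device{i}: {shard.tolist()}")
for i, copy in enumerate(out_copies):
    print(f"out device{i}: {copy.tolist()}")
```

If the real pjit call behaves as the question expects, each entry of `data.device_buffers` should hold the full 2x4 array, matching `out_copies` here.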
Replies: 1 comment 2 replies
Which jax version are you using? I can run this code successfully on my end with the following stdout (I'm using jax 0.3.15):
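
One way to answer the version question is to read the installed package metadata. The snippet below is a small sketch using only the standard library (`importlib.metadata`, Python 3.8+); the helper name `installed_version` is hypothetical, and checking `jaxlib` alongside `jax` is an assumption about what is useful to report, not something requested in the thread.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report both packages; a missing package prints as None.
for name in ("jax", "jaxlib"):
    print(name, installed_version(name))
```

Inside a session where JAX is already imported, `print(jax.__version__)` gives the same information for `jax` itself.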