-
I am able to run multi-host GPU with JAX:

```text
global devices= [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=1), GpuDevice(id=2, process_index=2)]
local devices= [GpuDevice(id=1, process_index=1)]
```

I wonder how to use `pmap` across all of these devices. I am trying this example:

```python
# Run this script on 3 GPU nodes, assuming 10.128.0.6 is the master node
# python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=3 --host_idx=0
# python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=3 --host_idx=1
# python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=3 --host_idx=2
from absl import app
from absl import flags
import jax
import jax.numpy as jnp
import numpy as np
from jax import pmap
from jax.lax import pmean
from jax.experimental.pjit import pjit, PartitionSpec as P
from jax.experimental import maps

flags.DEFINE_string('server_addr', '', help='server ip addr')
flags.DEFINE_integer('num_hosts', 1, help='num of hosts')
flags.DEFINE_integer('host_idx', 0, help='index of current host')
FLAGS = flags.FLAGS


def main(argv):
    jax.distributed.initialize(FLAGS.server_addr, FLAGS.num_hosts, FLAGS.host_idx)
    print('global devices=', jax.devices())
    print('local devices=', jax.local_devices())
    print('processing id=', jax.process_index())
    out = pmap(lambda x: x ** 2, devices=jax.devices())(jnp.arange(3))
    out2 = pmap(lambda x: x ** 2 - pmean(x ** 4, axis_name='i'),
                axis_name='i', devices=jax.devices())(jnp.arange(3))
    print(out)
    print(out2)


if __name__ == '__main__':
    app.run(main)
```

The error I was getting is:

```text
local devices= [GpuDevice(id=1, process_index=1)]
processing id= 1
Traceback (most recent call last):
  File "test_pmap_dist.py", line 60, in <module>
    app.run(main)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "test_pmap_dist.py", line 33, in main
    out = pmap(lambda x: x ** 2, devices=jax.devices())(jnp.arange(3))
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/_src/api.py", line 2026, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/_src/api.py", line 1902, in pmap_f
    out = pxla.xla_pmap(
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/core.py", line 1859, in bind
    return map_bind(self, fun, *args, **params)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/core.py", line 1891, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/core.py", line 1862, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/core.py", line 680, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 792, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 823, in parallel_callable
    pmap_executable = pmap_computation.compile()
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1091, in compile
    self._executable = PmapExecutable.from_hlo(self._hlo, **self.compile_args)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1138, in from_hlo
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Leading axis size of input to pmapped function must equal the number of local devices passed to pmap. Got axis_size=3, num_local_devices=1.
(Local devices available to pmap: gpu:1)
```
-
sorry, just to bump it up!
-
This is working as intended. When using JAX in multicontroller mode, each process should only pass `pmap` arguments for its own local devices. Here, you have one local device in each process, so your `pmap` should receive an array with a leading axis of size 1. See https://jax.readthedocs.io/en/latest/multi_process.html

Hope that helps!
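For what it's worth, here is a minimal sketch of how the `main` body above could be adjusted to follow that rule. It is not from the reply; the shard construction using `jax.process_index()` is my own illustration, and it assumes `jax.distributed.initialize(...)` has already been called exactly as in the original script.

```python
import jax
import jax.numpy as jnp
from jax import pmap
from jax.lax import pmean

# Number of devices visible to this process vs. across all processes.
n_local = jax.local_device_count()   # 1 per process in this 3-node setup
n_global = jax.device_count()        # 3 across all processes

# Each process supplies only its own shard: leading axis == n_local.
# Offsetting by process_index just gives each host distinct values.
x = jnp.arange(n_local) + jax.process_index() * n_local

# pmap defaults to the local devices; no devices=jax.devices() needed.
out = pmap(lambda v: v ** 2)(x)

# Collectives like pmean still operate over the global axis of size
# n_global, even though each process only feeds its local shard.
out2 = pmap(lambda v: v ** 2 - pmean(v ** 4, axis_name='i'),
            axis_name='i')(x)

print(jax.process_index(), out, out2)
```

With this shape, the pmapped input has leading axis 1 on each host, which is what the error message is asking for.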