Hi JAX team! Can I do `jax.lax.psum` on 3 TPU devices? The following code works on GPU devices. I couldn't find any related issues by googling, so I thought I might post a question here :)
```python
import jax
import jax.numpy as jnp
from jax import pmap

# Define a function to be parallelized
def my_function(x):
    return jax.lax.psum(x, 'i')


# Define an input array
x = jnp.arange(3)
print(jax.devices())
parallel_fn = pmap(my_function, devices=jax.devices()[:3], axis_name='i')
result = parallel_fn(x)
print(result)
```
```
    compiled = dispatch.compile_or_get_cached(
  File "/home/costahuang/.cache/pypoetry/virtualenvs/cleanba-pOszUjJ6-py3.9/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1077, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/costahuang/.cache/pypoetry/virtualenvs/cleanba-pOszUjJ6-py3.9/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/costahuang/.cache/pypoetry/virtualenvs/cleanba-pOszUjJ6-py3.9/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1012, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: Encountered AllReduce on 3 replicas, which is odd and greater than one. This case is not implemented for TPU.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/costahuang/cleanba/test.py", line 14, in <module>
    result = parallel_fn(x)
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: Encountered AllReduce on 3 replicas, which is odd and greater than one. This case is not implemented for TPU.
```
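The error message suggests the TPU all-reduce only rejects replica counts that are odd and greater than one. A minimal sketch of one possible workaround, assuming the host exposes at least 4 TPU devices: zero-pad the input up to an even number of devices so the padding contributes nothing to the sum, then drop the padded outputs. `padded_psum` and its `num_devices` parameter are hypothetical names, and the `XLA_FLAGS` line only simulates 4 devices so the sketch can be tried on a CPU host.

```python
import os

# Assumption, for local testing only: simulate 4 CPU devices. This must run
# before JAX is imported; on a real TPU host it can be removed.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
from jax import pmap


def my_function(x):
    # Sum x over all devices mapped along the named axis 'i'.
    return jax.lax.psum(x, "i")


def padded_psum(x, num_devices=4):
    """Hypothetical helper: psum over an even device count via zero padding."""
    n = x.shape[0]
    if n > num_devices:
        raise ValueError(f"need at least {n} devices, got num_devices={num_devices}")
    # Zero rows add nothing to the sum, so the real entries are unaffected.
    padding = jnp.zeros((num_devices - n,) + x.shape[1:], dtype=x.dtype)
    x_padded = jnp.concatenate([x, padding])
    parallel_fn = pmap(my_function, axis_name="i", devices=jax.devices()[:num_devices])
    # Keep only the outputs that correspond to the original n inputs.
    return parallel_fn(x_padded)[:n]


x = jnp.arange(3)
print(padded_psum(x))  # every entry equals 0 + 1 + 2 = 3
```

This only helps if a fourth device is actually available, since the padded replica still occupies one.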