We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 783285a + 02f65bb commit fe83d88Copy full SHA for fe83d88
jax/_src/interpreters/pxla.py
@@ -1868,7 +1868,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
1868
"does not preserve sharded data representations and instead collects "
1869
"input and output arrays onto a single device. "
1870
"Consider removing the outer jit unless you know what you're doing. "
1871
- "See https://github.com/jax-ml/jax/issues/2926.")
+ "See https://github.com/jax-ml/jax/issues/2926. Or "
1872
+ "use jax.experimental.shard_map instead of pmap under jit compilation.")
1873
1874
if nreps > xb.device_count(backend):
1875
raise ValueError(
0 commit comments