Skip to content

Commit fe83d88

Browse files
Merge pull request jax-ml#24417 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 688159150
2 parents 783285a + 02f65bb commit fe83d88

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1868,7 +1868,8 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
18681868
"does not preserve sharded data representations and instead collects "
18691869
"input and output arrays onto a single device. "
18701870
"Consider removing the outer jit unless you know what you're doing. "
1871-
"See https://github.com/jax-ml/jax/issues/2926.")
1871+
"See https://github.com/jax-ml/jax/issues/2926. Or "
1872+
"use jax.experimental.shard_map instead of pmap under jit compilation.")
18721873

18731874
if nreps > xb.device_count(backend):
18741875
raise ValueError(

0 commit comments

Comments
 (0)