Commit e210507
[pmap] Handle edge cases in get_from_first_device for jax_pmap_shmap_merge.
pmap replication/sharding semantics are not consistent, so different users of the acme library do different things. The real solution is to migrate away from pmap entirely, but IIUC, there may not be staffing for this.
Add a robust _unreplicate helper function that handles:
- 0-dimensional scalar arrays (cannot be indexed)
- SingleDeviceSharding or single local device (use simple indexing)
- Fully replicated arrays (return addressable shard data without squeeze)
- Default sharded case (squeeze the leading replicated dimension)
This prevents errors when processing nested structures containing arrays
with different sharding configurations under jax_pmap_shmap_merge mode.
PiperOrigin-RevId: 864435972
Change-Id: I6e151cbe0ed771f62762ec4f2d8e11a972e2a4ac1 parent 05ced10 commit e210507
1 file changed
+15
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
387 | 387 | | |
388 | 388 | | |
389 | 389 | | |
390 | | - | |
391 | | - | |
392 | | - | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
393 | 405 | | |
394 | 406 | | |
395 | 407 | | |
| |||
0 commit comments