1
+ diff --git a/t5x/partitioning.py b/t5x/partitioning.py
2
+ index 20b0abb..b19ecc1 100644
3
+ --- a/t5x/partitioning.py
4
+ +++ b/t5x/partitioning.py
5
+ @@ -78,13 +78,13 @@ def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh:
6
+ # Must be passed the device at the highest-coordinate corner of the
7
+ # relevant mesh, which is a requirement we know is satisfied by the last
8
+ # device in jax.devices().
9
+ - if hasattr(last_device, 'coords'):
10
+ - x, y, z = last_device.coords
11
+ - return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
12
+ - else:
13
+ + # if hasattr(last_device, 'coords'):
14
+ + # x, y, z = last_device.coords
15
+ + # return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
16
+ + # else:
17
+ # On non-TPU platforms, the "mesh" is hosts x devices per host in order
18
+ # to take advantage of faster within-host interconnect.
19
+ - return jax.host_count(), jax.local_device_count()
20
+ + return jax.host_count(), jax.local_device_count()
21
+
22
+
23
+ def get_coords(device: jax.Device) -> HardwareMesh:
0 commit comments