Skip to content

Commit 5fdc77f

Browse files
authored
[t5x] add patch (#229)
1 parent a618c8b commit 5fdc77f

File tree

2 files changed: +24 -0 lines changed

example/t5/install_xpu.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ git checkout 6699ad54480a0691c491fa2aa28d8b46daf1a377
 git apply ../patch/not_exit_before_max_step.patch
 git apply ../patch/version_time_dlpath.patch
 git apply ../patch/adjust_flax.patch
+git apply ../patch/correct_device_attr.patch
 
 ln -s /usr/local/bin/pip /usr/bin/pip
 pip uninstall tensorflow-metadata numba cudf -y
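(File name not captured in this extract; presumably the newly added correct_device_attr.patch that install_xpu.sh applies.)

Lines changed: 23 additions & 0 deletions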
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
diff --git a/t5x/partitioning.py b/t5x/partitioning.py
2+
index 20b0abb..b19ecc1 100644
3+
--- a/t5x/partitioning.py
4+
+++ b/t5x/partitioning.py
5+
@@ -78,13 +78,13 @@ def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh:
6+
# Must be passed the device at the highest-coordinate corner of the
7+
# relevant mesh, which is a requirement we know is satisfied by the last
8+
# device in jax.devices().
9+
- if hasattr(last_device, 'coords'):
10+
- x, y, z = last_device.coords
11+
- return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
12+
- else:
13+
+ # if hasattr(last_device, 'coords'):
14+
+ # x, y, z = last_device.coords
15+
+ # return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
16+
+ # else:
17+
# On non-TPU platforms, the "mesh" is hosts x devices per host in order
18+
# to take advantage of faster within-host interconnect.
19+
- return jax.host_count(), jax.local_device_count()
20+
+ return jax.host_count(), jax.local_device_count()
21+
22+
23+
def get_coords(device: jax.Device) -> HardwareMesh:

0 commit comments

Comments
 (0)