Skip to content

Commit 54e72d5

Browse files
sharadmv authored and Google-ML-Automation committed
Add wraparound for 2x2x2 v5p
PiperOrigin-RevId: 695603337
1 parent 38d062d commit 54e72d5

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

jax/_src/mesh_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_TPU_V4 = 'TPU v4'
3434
_TPU_V5_LITE = "TPU v5 lite"
3535
_TPU_V5E = "TPU v5e"
36+
_TPU_V5P = "TPU v5p"
3637

3738
# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's
3839
# famous contiguous mesh trick.
@@ -70,6 +71,7 @@
7071
_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4)
7172
_V5E_TRAY_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 4)
7273
_V5E_TRAY_IOTA_ORDER = (0, 4, 2, 6, 1, 5, 3, 7)
74+
_V5P_2x2x2_ORDER = (0, 1, 3, 2, 6, 7, 5, 4)
7375

7476
def _tpu_v2_v3_create_device_mesh(
7577
mesh_shape: Sequence[int],
@@ -148,6 +150,35 @@ def _v5e_create_device_mesh(
148150
return None
149151

150152

153+
def _v5p_create_device_mesh(
154+
mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs
155+
) -> np.ndarray | None:
156+
"""Creates device assignment for selected topologies.
157+
158+
Args:
159+
mesh_shape: Logical mesh shape used by the model.
160+
devices: TPU devices.
161+
**unused_kwargs: ...
162+
163+
Returns:
164+
None or reordered devices reshaped as `mesh_shape`.
165+
"""
166+
max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices)
167+
bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1
168+
# Our ring re-ordering makes sense only if the passed-in devices are
169+
# sequential, which may not always be the case. reversed() changes z-minor to
170+
# x-minor.
171+
sequential_devices = sorted(
172+
devices,
173+
key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0)))))
174+
175+
if bound_x == bound_y == 2 and bound_z == 2:
176+
device_mesh = np.asarray(sequential_devices)
177+
device_mesh = device_mesh[np.array(_V5P_2x2x2_ORDER)]
178+
device_mesh = device_mesh.reshape(mesh_shape)
179+
return device_mesh
180+
return None
181+
151182
# Registers functions to create device mesh for specific device kinds. Takes
152183
# precedence over the more general logic in create_device_mesh(). Handler may
153184
# return None; in that case, it will fall back to using the default logic.
@@ -158,6 +189,7 @@ def _v5e_create_device_mesh(
158189
_TPU_V2: _tpu_v2_v3_create_device_mesh,
159190
_TPU_V3: _tpu_v2_v3_create_device_mesh,
160191
_TPU_V5_LITE: _v5e_create_device_mesh,
192+
_TPU_V5P: _v5p_create_device_mesh,
161193
}
162194

163195

0 commit comments

Comments
 (0)