3333_TPU_V4 = 'TPU v4'
3434_TPU_V5_LITE = "TPU v5 lite"
3535_TPU_V5E = "TPU v5e"
36+ _TPU_V5P = "TPU v5p"
3637
3738# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's
3839# famous contiguous mesh trick.
7071_TRAY_4x4_RING_ORDER = (0 , 1 , 2 , 3 , 7 , 6 , 5 , 9 , 10 , 11 , 15 , 14 , 13 , 12 , 8 , 4 )
7172_V5E_TRAY_RING_ORDER = (0 , 1 , 2 , 3 , 7 , 6 , 5 , 4 )
7273_V5E_TRAY_IOTA_ORDER = (0 , 4 , 2 , 6 , 1 , 5 , 3 , 7 )
74+ _V5P_2x2x2_ORDER = (0 , 1 , 3 , 2 , 6 , 7 , 5 , 4 )
7375
7476def _tpu_v2_v3_create_device_mesh (
7577 mesh_shape : Sequence [int ],
@@ -148,6 +150,35 @@ def _v5e_create_device_mesh(
148150 return None
149151
150152
153+ def _v5p_create_device_mesh (
154+ mesh_shape : Sequence [int ], devices : Sequence [Any ], ** unused_kwargs
155+ ) -> np .ndarray | None :
156+ """Creates device assignment for selected topologies.
157+
158+ Args:
159+ mesh_shape: Logical mesh shape used by the model.
160+ devices: TPU devices.
161+ **unused_kwargs: ...
162+
163+ Returns:
164+ None or reordered devices reshaped as `mesh_shape`.
165+ """
166+ max_x , max_y , max_z = max (getattr (d , "coords" , (0 , 0 , 0 )) for d in devices )
167+ bound_x , bound_y , bound_z = max_x + 1 , max_y + 1 , max_z + 1
168+ # Our ring re-ordering makes sense only if the passed-in devices are
169+ # sequential, which may not always be the case. reversed() changes z-minor to
170+ # x-minor.
171+ sequential_devices = sorted (
172+ devices ,
173+ key = lambda d : tuple (reversed (getattr (d , "coords" , (0 , 0 , 0 )))))
174+
175+ if bound_x == bound_y == 2 and bound_z == 2 :
176+ device_mesh = np .asarray (sequential_devices )
177+ device_mesh = device_mesh [np .array (_V5P_2x2x2_ORDER )]
178+ device_mesh = device_mesh .reshape (mesh_shape )
179+ return device_mesh
180+ return None
181+
151182# Registers functions to create device mesh for specific device kinds. Takes
152183# precedence over the more general logic in create_device_mesh(). Handler may
153184# return None; in that case, it will fall back to using the default logic.
@@ -158,6 +189,7 @@ def _v5e_create_device_mesh(
158189 _TPU_V2 : _tpu_v2_v3_create_device_mesh ,
159190 _TPU_V3 : _tpu_v2_v3_create_device_mesh ,
160191 _TPU_V5_LITE : _v5e_create_device_mesh ,
192+ _TPU_V5P : _v5p_create_device_mesh ,
161193}
162194
163195
0 commit comments