Skip to content

Commit f3ac881

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic] Avoid using the deprecated output_memory_colors in CustomCallBackendConfig
PiperOrigin-RevId: 863192178
1 parent 9489ca7 commit f3ac881

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -225,23 +225,28 @@ def to_json(self) -> bytes:
225225
config.write(b', "internal_scratch_in_bytes": ')
226226
config.write(str(self.internal_scratch_in_bytes).encode("ascii"))
227227
if self.output_memory_spaces is not None:
228-
config.write(b', "output_memory_colors": [')
228+
is_tuple = len(self.output_memory_spaces) > 1
229229
for i, memory_space in enumerate(self.output_memory_spaces):
230230
if i:
231231
config.write(b",")
232-
color = memory_space.color if memory_space is not None else -1
233-
config.write(str(color).encode("ascii"))
232+
else:
233+
config.write(b', "output_memory_space_colors": [')
234+
color = memory_space.color if memory_space is not None else 6
235+
config.write(f'{{"color":{color}'.encode("ascii"))
236+
if is_tuple:
237+
config.write(f',"shape_index":{i}'.encode("ascii"))
238+
config.write(b"}")
234239
config.write(b"]")
235240
if self.input_memory_spaces is not None:
236241
comma = False
237-
for i, input_memory_space in enumerate(self.input_memory_spaces):
238-
if input_memory_space is None:
242+
for i, memory_space in enumerate(self.input_memory_spaces):
243+
if memory_space is None:
239244
continue
240-
if input_memory_space is MemorySpace.SMEM:
245+
if memory_space is MemorySpace.SMEM:
241246
# TODO(sharadmv): Add support for SMEM (though atm, XLA will not
242247
# page out SMEM arrays).
243248
continue
244-
if input_memory_space not in (
249+
if memory_space not in (
245250
MemorySpace.HBM,
246251
MemorySpace.VMEM,
247252
MemorySpace.SMEM,
@@ -254,7 +259,7 @@ def to_json(self) -> bytes:
254259
else:
255260
config.write(b', "input_memory_space_colors": [')
256261
config.write(
257-
f'{{"operand_index":{i},"color":{input_memory_space.color}}}'
262+
f'{{"operand_index":{i},"color":{memory_space.color}}}'
258263
.encode("ascii")
259264
)
260265
comma = True

0 commit comments

Comments
 (0)