Skip to content

Commit fc1b321

Browse files
Reverts f3ac881
PiperOrigin-RevId: 863233898
1 parent f3ac881 commit fc1b321

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -225,28 +225,23 @@ def to_json(self) -> bytes:
225225
config.write(b', "internal_scratch_in_bytes": ')
226226
config.write(str(self.internal_scratch_in_bytes).encode("ascii"))
227227
if self.output_memory_spaces is not None:
228-
is_tuple = len(self.output_memory_spaces) > 1
228+
config.write(b', "output_memory_colors": [')
229229
for i, memory_space in enumerate(self.output_memory_spaces):
230230
if i:
231231
config.write(b",")
232-
else:
233-
config.write(b', "output_memory_space_colors": [')
234-
color = memory_space.color if memory_space is not None else 6
235-
config.write(f'{{"color":{color}'.encode("ascii"))
236-
if is_tuple:
237-
config.write(f',"shape_index":{i}'.encode("ascii"))
238-
config.write(b"}")
232+
color = memory_space.color if memory_space is not None else -1
233+
config.write(str(color).encode("ascii"))
239234
config.write(b"]")
240235
if self.input_memory_spaces is not None:
241236
comma = False
242-
for i, memory_space in enumerate(self.input_memory_spaces):
243-
if memory_space is None:
237+
for i, input_memory_space in enumerate(self.input_memory_spaces):
238+
if input_memory_space is None:
244239
continue
245-
if memory_space is MemorySpace.SMEM:
240+
if input_memory_space is MemorySpace.SMEM:
246241
# TODO(sharadmv): Add support for SMEM (though atm, XLA will not
247242
# page out SMEM arrays).
248243
continue
249-
if memory_space not in (
244+
if input_memory_space not in (
250245
MemorySpace.HBM,
251246
MemorySpace.VMEM,
252247
MemorySpace.SMEM,
@@ -259,7 +254,7 @@ def to_json(self) -> bytes:
259254
else:
260255
config.write(b', "input_memory_space_colors": [')
261256
config.write(
262-
f'{{"operand_index":{i},"color":{memory_space.color}}}'
257+
f'{{"operand_index":{i},"color":{input_memory_space.color}}}'
263258
.encode("ascii")
264259
)
265260
comma = True

0 commit comments

Comments
 (0)