@@ -225,28 +225,23 @@ def to_json(self) -> bytes:
225225 config .write (b', "internal_scratch_in_bytes": ' )
226226 config .write (str (self .internal_scratch_in_bytes ).encode ("ascii" ))
227227 if self .output_memory_spaces is not None :
228- is_tuple = len ( self . output_memory_spaces ) > 1
228+ config . write ( b', "output_memory_colors": [' )
229229 for i , memory_space in enumerate (self .output_memory_spaces ):
230230 if i :
231231 config .write (b"," )
232- else :
233- config .write (b', "output_memory_space_colors": [' )
234- color = memory_space .color if memory_space is not None else 6
235- config .write (f'{{"color":{ color } ' .encode ("ascii" ))
236- if is_tuple :
237- config .write (f',"shape_index":{ i } ' .encode ("ascii" ))
238- config .write (b"}" )
232+ color = memory_space .color if memory_space is not None else - 1
233+ config .write (str (color ).encode ("ascii" ))
239234 config .write (b"]" )
240235 if self .input_memory_spaces is not None :
241236 comma = False
242- for i , memory_space in enumerate (self .input_memory_spaces ):
243- if memory_space is None :
237+ for i , input_memory_space in enumerate (self .input_memory_spaces ):
238+ if input_memory_space is None :
244239 continue
245- if memory_space is MemorySpace .SMEM :
240+ if input_memory_space is MemorySpace .SMEM :
246241 # TODO(sharadmv): Add support for SMEM (though atm, XLA will not
247242 # page out SMEM arrays).
248243 continue
249- if memory_space not in (
244+ if input_memory_space not in (
250245 MemorySpace .HBM ,
251246 MemorySpace .VMEM ,
252247 MemorySpace .SMEM ,
@@ -259,7 +254,7 @@ def to_json(self) -> bytes:
259254 else :
260255 config .write (b', "input_memory_space_colors": [' )
261256 config .write (
262- f'{{"operand_index":{ i } ,"color":{ memory_space .color } }}'
257+ f'{{"operand_index":{ i } ,"color":{ input_memory_space .color } }}'
263258 .encode ("ascii" )
264259 )
265260 comma = True
0 commit comments