@@ -225,23 +225,28 @@ def to_json(self) -> bytes:
225225 config .write (b', "internal_scratch_in_bytes": ' )
226226 config .write (str (self .internal_scratch_in_bytes ).encode ("ascii" ))
227227 if self .output_memory_spaces is not None :
228- config . write ( b', "output_memory_colors": [' )
228+ is_tuple = len ( self . output_memory_spaces ) > 1
229229 for i , memory_space in enumerate (self .output_memory_spaces ):
230230 if i :
231231 config .write (b"," )
232- color = memory_space .color if memory_space is not None else - 1
233- config .write (str (color ).encode ("ascii" ))
232+ else :
233+ config .write (b', "output_memory_space_colors": [' )
234+ color = memory_space .color if memory_space is not None else 6
235+ config .write (f'{{"color":{ color } ' .encode ("ascii" ))
236+ if is_tuple :
237+ config .write (f',"shape_index":{ i } ' .encode ("ascii" ))
238+ config .write (b"}" )
234239 config .write (b"]" )
235240 if self .input_memory_spaces is not None :
236241 comma = False
237- for i , input_memory_space in enumerate (self .input_memory_spaces ):
238- if input_memory_space is None :
242+ for i , memory_space in enumerate (self .input_memory_spaces ):
243+ if memory_space is None :
239244 continue
240- if input_memory_space is MemorySpace .SMEM :
245+ if memory_space is MemorySpace .SMEM :
241246 # TODO(sharadmv): Add support for SMEM (though atm, XLA will not
242247 # page out SMEM arrays).
243248 continue
244- if input_memory_space not in (
249+ if memory_space not in (
245250 MemorySpace .HBM ,
246251 MemorySpace .VMEM ,
247252 MemorySpace .SMEM ,
@@ -254,7 +259,7 @@ def to_json(self) -> bytes:
254259 else :
255260 config .write (b', "input_memory_space_colors": [' )
256261 config .write (
257- f'{{"operand_index":{ i } ,"color":{ input_memory_space .color } }}'
262+ f'{{"operand_index":{ i } ,"color":{ memory_space .color } }}'
258263 .encode ("ascii" )
259264 )
260265 comma = True
0 commit comments