@@ -60,24 +60,29 @@ def enable(self, max_events=1_000_000):
6060 "timestamp" : torch .zeros (max_events , dtype = torch .int64 , device = device ),
6161 "address" : torch .zeros (max_events , dtype = torch .int64 , device = device ),
6262 "duration_cycles" : torch .zeros (max_events , dtype = torch .int64 , device = device ),
63+ "op_index" : torch .zeros (max_events , dtype = torch .int32 , device = device ),
64+ "payload_size" : torch .zeros (max_events , dtype = torch .int32 , device = device ),
6365 }
6466
6567 # Atomic counter for event indexing
6668 self .trace_counter = torch .zeros (1 , dtype = torch .int32 , device = device )
69+ # Atomic counter for operation indexing (tracks operation order)
70+ self .op_index_counter = torch .zeros (1 , dtype = torch .int32 , device = device )
6771
6872 self .iris .info (f"Device tracing enabled with max { max_events } events" )
6973
def reset(self):
    """
    Start a new trace capture by rewinding the counters.

    Zeroes the event counter and the operation index counter in place.
    The trace buffers themselves are not touched, so they stay allocated.
    If tracing has not been enabled, a warning is logged and nothing is reset.
    """
    if self.enabled:
        # Rewind the counters one after another, event counter first.
        for counter_name in ("trace_counter", "op_index_counter"):
            getattr(self, counter_name).zero_()
        self.iris.debug("Trace buffers reset")
    else:
        self.iris.warning("Tracing not enabled. Call tracing.enable() first.")
8287
8388 def _collect_system_metadata (self ):
@@ -139,6 +144,8 @@ def _build_trace_events(self, num_events):
139144 "address" : hex (int (self .trace_buffers ["address" ][i ].item ())),
140145 "xcc_id" : xcc_id ,
141146 "cu_id" : cu_id ,
147+ "op_index" : int (self .trace_buffers ["op_index" ][i ].item ()),
148+ "payload_size" : int (self .trace_buffers ["payload_size" ][i ].item ()),
142149 },
143150 }
144151
@@ -198,11 +205,31 @@ def export(self, filename="trace.json", merge=False):
198205 "traceEvents" : trace_events ,
199206 "displayTimeUnit" : "ns" ,
200207 "metadata" : {
201- "schema_version" : "1.0 " ,
208+ "schema_version" : "1.1 " ,
202209 "num_events" : num_events ,
203210 "rank" : self .iris .cur_rank ,
204211 "world_size" : self .iris .num_ranks ,
205212 "time_unit" : "raw cycles (s_memrealtime @ 100MHz)" ,
213+ "fields" : {
214+ "name" : "Event type name (e.g., 'put', 'get', 'load', 'store')" ,
215+ "cat" : "Event category (always 'iris')" ,
216+ "ts" : "Start timestamp in raw cycles" ,
217+ "pid" : "Process ID (current rank)" ,
218+ "tid" : "Thread ID (XCC{id}_CU{id})" ,
219+ "ph" : "Phase: 'X' for complete events, 'i' for instant events" ,
220+ "dur" : "Duration in cycles (only for complete events)" ,
221+ "args" : {
222+ "program_id" : "Triton program ID (block ID)" ,
223+ "pid_m" : "Program ID in M dimension" ,
224+ "pid_n" : "Program ID in N dimension" ,
225+ "target_rank" : "Target rank for the operation" ,
226+ "address" : "Memory address (hex) - min of address block" ,
227+ "xcc_id" : "XCC (chiplet) ID where event occurred" ,
228+ "cu_id" : "Compute Unit ID where event occurred" ,
229+ "op_index" : "Operation index (0, 1, 2, ...) - automatically tracked" ,
230+ "payload_size" : "Payload size in bytes - automatically calculated from mask and datatype" ,
231+ },
232+ },
206233 ** system_metadata ,
207234 },
208235 }
@@ -255,13 +282,33 @@ def export(self, filename="trace.json", merge=False):
255282 "traceEvents" : all_events ,
256283 "displayTimeUnit" : "ns" ,
257284 "metadata" : {
258- "schema_version" : "1.0 " ,
285+ "schema_version" : "1.1 " ,
259286 "total_events" : len (all_events ),
260287 "max_events" : self .max_events ,
261288 "time_unit" : "cycles (s_memrealtime @ 100MHz)" ,
262289 "world_size" : self .iris .num_ranks ,
263290 "timestamp_offset" : min_ts if all_timestamps else 0 ,
264291 "aligned" : "minimum timestamp across all ranks" ,
292+ "fields" : {
293+ "name" : "Event type name (e.g., 'put', 'get', 'load', 'store')" ,
294+ "cat" : "Event category (always 'iris')" ,
295+ "ts" : "Start timestamp in raw cycles" ,
296+ "pid" : "Process ID (current rank)" ,
297+ "tid" : "Thread ID (XCC{id}_CU{id})" ,
298+ "ph" : "Phase: 'X' for complete events, 'i' for instant events" ,
299+ "dur" : "Duration in cycles (only for complete events)" ,
300+ "args" : {
301+ "program_id" : "Triton program ID (block ID)" ,
302+ "pid_m" : "Program ID in M dimension" ,
303+ "pid_n" : "Program ID in N dimension" ,
304+ "target_rank" : "Target rank for the operation" ,
305+ "address" : "Memory address (hex) - min of address block" ,
306+ "xcc_id" : "XCC (chiplet) ID where event occurred" ,
307+ "cu_id" : "Compute Unit ID where event occurred" ,
308+ "op_index" : "Operation index (0, 1, 2, ...) - automatically tracked" ,
309+ "payload_size" : "Payload size in bytes - automatically calculated from mask and datatype" ,
310+ },
311+ },
265312 ** system_metadata ,
266313 },
267314 }
0 commit comments