Skip to content

Commit 91f92ed

Browse files
Add op_index and payload_size tracking to device tracing (#424)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 084d6ba commit 91f92ed

File tree

5 files changed

+127
-8
lines changed

5 files changed

+127
-8
lines changed

examples/23_gemm_all_scatter_tracing/gemm_all_scatter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,15 @@ def persistent_gemm_all_scatter(
147147
else:
148148
# Record duration event around remote store (compiles away if tracing=False)
149149
# Pass 2D pointer tensor; record_event_start takes min as representative address
150+
# op_index is automatically tracked internally (0, 1, 2, ...)
151+
# payload_size is automatically calculated from mask
150152
handle = ctx.tracing.record_event_start(
151153
event_id=TraceEvent().put,
152154
target_rank=remote_rank,
153155
address=c_global + global_offset,
154156
pid_m=pid_m,
155157
pid_n=pid_n,
158+
mask=sub_mask,
156159
)
157160

158161
# Use DeviceContext.put for remote stores

iris/iris.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -946,11 +946,14 @@ def get_device_context(self):
946946
self.tracing.trace_buffers["timestamp"].data_ptr(),
947947
self.tracing.trace_buffers["address"].data_ptr(),
948948
self.tracing.trace_buffers["duration_cycles"].data_ptr(),
949+
self.tracing.trace_buffers["op_index"].data_ptr(),
950+
self.tracing.trace_buffers["payload_size"].data_ptr(),
949951
]
950952
context_data += [
951953
1, # trace_enabled = 1 (true)
952954
self.tracing.max_events,
953955
self.tracing.trace_counter.data_ptr(),
956+
self.tracing.op_index_counter.data_ptr(),
954957
] + trace_buffer_ptrs
955958
else:
956959
context_data += [0] # trace_enabled = 0 (false)
@@ -1384,7 +1387,8 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False):
13841387
>>>
13851388
>>> # With tracing
13861389
>>> ctx = DeviceContext.initialize(context_tensor, rank, world_size, tracing=True)
1387-
>>> ctx.tracing.record_event_start(event_id=TraceEvent().put, target_rank=1, address=ptr)
1390+
>>> mask = tl.full([64], True, dtype=tl.int1) # Example mask
1391+
>>> ctx.tracing.record_event_start(event_id=TraceEvent().put, target_rank=1, address=ptr, pid_m=0, pid_n=0, mask=mask)
13881392
"""
13891393
# Extract heap bases (from index 2 onwards)
13901394
heap_bases = context_tensor + 2 # Offset pointer to start at heap bases
@@ -1394,12 +1398,14 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False):
13941398
trace_info_idx = 2 + world_size + 1 # Skip: cur_rank, num_ranks, heap_bases, trace_enabled flag
13951399
max_events = tl.load(context_tensor + trace_info_idx + 0)
13961400
trace_counter_ptr = tl.load(context_tensor + trace_info_idx + 1)
1401+
op_index_counter_ptr = tl.load(context_tensor + trace_info_idx + 2)
13971402

1398-
# Cast trace_counter_ptr to pointer type
1403+
# Cast counter pointers to pointer type
13991404
trace_counter = tl.cast(trace_counter_ptr, tl.pointer_type(tl.int32))
1405+
op_index_counter = tl.cast(op_index_counter_ptr, tl.pointer_type(tl.int32))
14001406

1401-
# Extract trace buffer pointers (11 buffers)
1402-
base_idx = trace_info_idx + 2
1407+
# Extract trace buffer pointers (13 buffers)
1408+
base_idx = trace_info_idx + 3 # Updated: +3 because we now have op_index_counter
14031409
trace_buf_event_id = tl.cast(tl.load(context_tensor + base_idx + 0), tl.pointer_type(tl.int32))
14041410
trace_buf_pid = tl.cast(tl.load(context_tensor + base_idx + 1), tl.pointer_type(tl.int32))
14051411
trace_buf_pid_m = tl.cast(tl.load(context_tensor + base_idx + 2), tl.pointer_type(tl.int32))
@@ -1411,13 +1417,16 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False):
14111417
trace_buf_timestamp = tl.cast(tl.load(context_tensor + base_idx + 8), tl.pointer_type(tl.int64))
14121418
trace_buf_address = tl.cast(tl.load(context_tensor + base_idx + 9), tl.pointer_type(tl.int64))
14131419
trace_buf_duration_cycles = tl.cast(tl.load(context_tensor + base_idx + 10), tl.pointer_type(tl.int64))
1420+
trace_buf_op_index = tl.cast(tl.load(context_tensor + base_idx + 11), tl.pointer_type(tl.int32))
1421+
trace_buf_payload_size = tl.cast(tl.load(context_tensor + base_idx + 12), tl.pointer_type(tl.int32))
14141422

14151423
# Create DeviceTracing instance
14161424
device_tracing = DeviceTracing(
14171425
enabled=tracing,
14181426
rank=rank,
14191427
max_events=max_events,
14201428
counter=trace_counter,
1429+
op_index_counter=op_index_counter,
14211430
buf_event_id=trace_buf_event_id,
14221431
buf_pid=trace_buf_pid,
14231432
buf_pid_m=trace_buf_pid_m,
@@ -1429,6 +1438,8 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False):
14291438
buf_timestamp=trace_buf_timestamp,
14301439
buf_address=trace_buf_address,
14311440
buf_duration_cycles=trace_buf_duration_cycles,
1441+
buf_op_index=trace_buf_op_index,
1442+
buf_payload_size=trace_buf_payload_size,
14321443
)
14331444

14341445
return DeviceContext(rank, world_size, heap_bases, device_tracing)
@@ -1442,6 +1453,7 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False):
14421453
rank=rank,
14431454
max_events=max_events_zero,
14441455
counter=dummy_ptr_i32,
1456+
op_index_counter=dummy_ptr_i32,
14451457
buf_event_id=dummy_ptr_i32,
14461458
buf_pid=dummy_ptr_i32,
14471459
buf_pid_m=dummy_ptr_i32,
@@ -1453,6 +1465,8 @@ def initialize(context_tensor, rank, world_size, tracing: tl.constexpr = False):
14531465
buf_timestamp=dummy_ptr_i64,
14541466
buf_address=dummy_ptr_i64,
14551467
buf_duration_cycles=dummy_ptr_i64,
1468+
buf_op_index=dummy_ptr_i32,
1469+
buf_payload_size=dummy_ptr_i32,
14561470
)
14571471

14581472
return DeviceContext(rank, world_size, heap_bases, device_tracing)

iris/tracing/core.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,29 @@ def enable(self, max_events=1_000_000):
6060
"timestamp": torch.zeros(max_events, dtype=torch.int64, device=device),
6161
"address": torch.zeros(max_events, dtype=torch.int64, device=device),
6262
"duration_cycles": torch.zeros(max_events, dtype=torch.int64, device=device),
63+
"op_index": torch.zeros(max_events, dtype=torch.int32, device=device),
64+
"payload_size": torch.zeros(max_events, dtype=torch.int32, device=device),
6365
}
6466

6567
# Atomic counter for event indexing
6668
self.trace_counter = torch.zeros(1, dtype=torch.int32, device=device)
69+
# Atomic counter for operation indexing (tracks operation order)
70+
self.op_index_counter = torch.zeros(1, dtype=torch.int32, device=device)
6771

6872
self.iris.info(f"Device tracing enabled with max {max_events} events")
6973

7074
def reset(self):
7175
"""
7276
Reset trace counter to start a new trace capture.
7377
74-
Clears the event counter but keeps buffers allocated.
78+
Clears the event counter and operation index counter but keeps buffers allocated.
7579
"""
7680
if not self.enabled:
7781
self.iris.warning("Tracing not enabled. Call tracing.enable() first.")
7882
return
7983

8084
self.trace_counter.zero_()
85+
self.op_index_counter.zero_()
8186
self.iris.debug("Trace buffers reset")
8287

8388
def _collect_system_metadata(self):
@@ -139,6 +144,8 @@ def _build_trace_events(self, num_events):
139144
"address": hex(int(self.trace_buffers["address"][i].item())),
140145
"xcc_id": xcc_id,
141146
"cu_id": cu_id,
147+
"op_index": int(self.trace_buffers["op_index"][i].item()),
148+
"payload_size": int(self.trace_buffers["payload_size"][i].item()),
142149
},
143150
}
144151

@@ -198,11 +205,31 @@ def export(self, filename="trace.json", merge=False):
198205
"traceEvents": trace_events,
199206
"displayTimeUnit": "ns",
200207
"metadata": {
201-
"schema_version": "1.0",
208+
"schema_version": "1.1",
202209
"num_events": num_events,
203210
"rank": self.iris.cur_rank,
204211
"world_size": self.iris.num_ranks,
205212
"time_unit": "raw cycles (s_memrealtime @ 100MHz)",
213+
"fields": {
214+
"name": "Event type name (e.g., 'put', 'get', 'load', 'store')",
215+
"cat": "Event category (always 'iris')",
216+
"ts": "Start timestamp in raw cycles",
217+
"pid": "Process ID (current rank)",
218+
"tid": "Thread ID (XCC{id}_CU{id})",
219+
"ph": "Phase: 'X' for complete events, 'i' for instant events",
220+
"dur": "Duration in cycles (only for complete events)",
221+
"args": {
222+
"program_id": "Triton program ID (block ID)",
223+
"pid_m": "Program ID in M dimension",
224+
"pid_n": "Program ID in N dimension",
225+
"target_rank": "Target rank for the operation",
226+
"address": "Memory address (hex) - min of address block",
227+
"xcc_id": "XCC (chiplet) ID where event occurred",
228+
"cu_id": "Compute Unit ID where event occurred",
229+
"op_index": "Operation index (0, 1, 2, ...) - automatically tracked",
230+
"payload_size": "Payload size in bytes - automatically calculated from mask and datatype",
231+
},
232+
},
206233
**system_metadata,
207234
},
208235
}
@@ -255,13 +282,33 @@ def export(self, filename="trace.json", merge=False):
255282
"traceEvents": all_events,
256283
"displayTimeUnit": "ns",
257284
"metadata": {
258-
"schema_version": "1.0",
285+
"schema_version": "1.1",
259286
"total_events": len(all_events),
260287
"max_events": self.max_events,
261288
"time_unit": "cycles (s_memrealtime @ 100MHz)",
262289
"world_size": self.iris.num_ranks,
263290
"timestamp_offset": min_ts if all_timestamps else 0,
264291
"aligned": "minimum timestamp across all ranks",
292+
"fields": {
293+
"name": "Event type name (e.g., 'put', 'get', 'load', 'store')",
294+
"cat": "Event category (always 'iris')",
295+
"ts": "Start timestamp in raw cycles",
296+
"pid": "Process ID (current rank)",
297+
"tid": "Thread ID (XCC{id}_CU{id})",
298+
"ph": "Phase: 'X' for complete events, 'i' for instant events",
299+
"dur": "Duration in cycles (only for complete events)",
300+
"args": {
301+
"program_id": "Triton program ID (block ID)",
302+
"pid_m": "Program ID in M dimension",
303+
"pid_n": "Program ID in N dimension",
304+
"target_rank": "Target rank for the operation",
305+
"address": "Memory address (hex) - min of address block",
306+
"xcc_id": "XCC (chiplet) ID where event occurred",
307+
"cu_id": "Compute Unit ID where event occurred",
308+
"op_index": "Operation index (0, 1, 2, ...) - automatically tracked",
309+
"payload_size": "Payload size in bytes - automatically calculated from mask and datatype",
310+
},
311+
},
265312
**system_metadata,
266313
},
267314
}

iris/tracing/device.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class _DeviceTracingCls:
2929
enabled: tl.constexpr
3030
rank: tl.constexpr # current rank (from ctx)
3131
max_events: tl.tensor # scalar (0-dim)
32-
counter: tl.tensor # pointer to int32
32+
counter: tl.tensor # pointer to int32 (event counter)
33+
op_index_counter: tl.tensor # pointer to int32 (operation index counter)
3334
buf_event_id: tl.tensor
3435
buf_pid: tl.tensor
3536
buf_pid_m: tl.tensor
@@ -41,13 +42,16 @@ class _DeviceTracingCls:
4142
buf_timestamp: tl.tensor
4243
buf_address: tl.tensor
4344
buf_duration_cycles: tl.tensor
45+
buf_op_index: tl.tensor
46+
buf_payload_size: tl.tensor
4447

4548
def __init__(
4649
self,
4750
enabled,
4851
rank,
4952
max_events,
5053
counter,
54+
op_index_counter,
5155
buf_event_id,
5256
buf_pid,
5357
buf_pid_m,
@@ -59,12 +63,15 @@ def __init__(
5963
buf_timestamp,
6064
buf_address,
6165
buf_duration_cycles,
66+
buf_op_index,
67+
buf_payload_size,
6268
):
6369
"""Construct DeviceTracing (called from DeviceContext.initialize)."""
6470
self.enabled = enabled
6571
self.rank = rank
6672
self.max_events = max_events
6773
self.counter = counter
74+
self.op_index_counter = op_index_counter
6875
self.buf_event_id = buf_event_id
6976
self.buf_pid = buf_pid
7077
self.buf_pid_m = buf_pid_m
@@ -76,6 +83,8 @@ def __init__(
7683
self.buf_timestamp = buf_timestamp
7784
self.buf_address = buf_address
7885
self.buf_duration_cycles = buf_duration_cycles
86+
self.buf_op_index = buf_op_index
87+
self.buf_payload_size = buf_payload_size
7988

8089
@triton.jit
8190
def record_event_start(
@@ -85,18 +94,59 @@ def record_event_start(
8594
address,
8695
pid_m,
8796
pid_n,
97+
mask=None,
8898
):
8999
"""
90100
Record start of a traced operation. Returns a handle for record_event_end.
91101
92102
Only stores when event_idx.item() < max_events (bounds check).
93103
cur_rank is taken from the tracing context (ctx.rank).
104+
op_index is automatically tracked internally (0, 1, 2, ...).
105+
payload_size is automatically calculated from mask and datatype:
106+
- Counts True values in mask to get number of elements
107+
- Infers datatype size from address pointer type
108+
- Multiplies elements * bytes_per_element to get total bytes
109+
If mask is None, payload_size is set to 0 (unknown size).
110+
111+
Args:
112+
event_id: Event type ID (constexpr)
113+
target_rank: Target rank for the operation
114+
address: Memory address(es) - can be 1D or 2D block of pointers.
115+
The element type is inferred from address.type.element_ty
116+
pid_m: Program ID in M dimension
117+
pid_n: Program ID in N dimension
118+
mask: Optional mask tensor (1D or 2D) indicating valid elements.
119+
If provided, payload_size is calculated as:
120+
(count of True values) * (bytes per element from address dtype).
121+
If None, payload_size is set to 0.
94122
"""
95123
if not self.enabled:
96124
# Return dummy handle; record_event_end will no-op (0 < max_events is false when disabled)
97125
return tl.full((), 0, dtype=tl.int32)
98126

99127
event_idx = tl.atomic_add(self.counter, 1)
128+
op_index = tl.atomic_add(self.op_index_counter, 1)
129+
130+
# Calculate payload_size from mask and datatype
131+
if mask is not None:
132+
# Count True values in mask (True=1, False=0, so sum gives count of elements)
133+
mask_i32 = tl.cast(mask, tl.int32)
134+
num_elements = tl.sum(mask_i32)
135+
136+
# Get element type from address pointer and calculate size in bytes
137+
# address can be 1D or 2D block of pointers, all with same element type
138+
# For blocks, use .dtype instead of .type (like in test_atomic_xchg_triton.py)
139+
# address.dtype is the pointer type, address.dtype.element_ty is the element dtype
140+
elem_type = address.dtype.element_ty
141+
# Get size in bytes using primitive_bitwidth (bits / 8 = bytes)
142+
bitwidth = elem_type.primitive_bitwidth
143+
elem_size_bytes = bitwidth // 8
144+
# Calculate total payload size in bytes
145+
payload_size = num_elements * elem_size_bytes
146+
else:
147+
# No mask provided, set to 0 to indicate unknown size
148+
payload_size = tl.full((), 0, dtype=tl.int32)
149+
100150
if event_idx.item() < self.max_events.item():
101151
tl.store(self.buf_event_id + event_idx, event_id)
102152
tl.store(self.buf_pid + event_idx, tl.program_id(0))
@@ -111,6 +161,8 @@ def record_event_start(
111161
addr_i64 = tl.cast(address, tl.int64)
112162
tl.store(self.buf_address + event_idx, tl.min(addr_i64))
113163
tl.store(self.buf_duration_cycles + event_idx, tl.full((), 0, dtype=tl.int64))
164+
tl.store(self.buf_op_index + event_idx, op_index)
165+
tl.store(self.buf_payload_size + event_idx, payload_size)
114166
return event_idx
115167

116168
@triton.jit

tests/unittests/test_device_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ def device_context_tracing_1d_address_kernel(
2525
# 1D block of pointers: dummy_buffer + offsets
2626
offsets = tl.arange(0, BLOCK_SIZE)
2727
address_1d = dummy_buffer + offsets
28+
# Create a simple mask (all True for this test)
29+
mask = tl.full([BLOCK_SIZE], True, dtype=tl.int1)
2830
handle = ctx.tracing.record_event_start(
2931
event_id=TraceEvent().put,
3032
target_rank=(cur_rank + 1) % num_ranks,
3133
address=address_1d,
3234
pid_m=tl.program_id(0),
3335
pid_n=0,
36+
mask=mask,
3437
)
3538
ctx.tracing.record_event_end(handle)
3639

0 commit comments

Comments
 (0)