Skip to content

Commit 8406e4b

Browse files
[Cherry-pick] [proton] Add init and final timestamps to chrome trace (#8870) (#719)
Summary: Cherry-picked from upstream OAI repository. Original Commit: ecbb77c Original Author: Srivatsan Ramesh Original Date: 2025-12-02 12:32:13 -0500 Original commit message: ``` [proton] Add init and final timestamps to chrome trace (#8870) Currently the trace timing information are all obained from GPU clock and are scaled assuming a frequency of 1GHz but if the GPU operates at a different frequency the timing information becomes a bit misleading. In this PR, enhanced the Proton instrumentation trace output to include block-level timing metadata in each trace event, enabling better temporal analysis. ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #719 Reviewed By: fywkevin Differential Revision: D88415845 Pulled By: srivatsan-ramesh fbshipit-source-id: 3167a0e90fa06be42bd83ba121d46aad58add751
1 parent 98e0997 commit 8406e4b

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

third_party/proton/common/lib/TraceDataIO/TraceWriter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ void StreamChromeTraceWriter::writeKernel(json &object,
232232
element["ts"] = static_cast<double>(ts) / freq;
233233
element["dur"] = static_cast<double>(dur) / freq;
234234
json args;
235+
args["Init Time (ns)"] = bt->initTime;
236+
args["Post Final Time (ns)"] = bt->postFinalTime;
235237
args["Finalization Time (ns)"] = bt->postFinalTime - bt->preFinalTime;
236238
args["Frequency (MHz)"] = freq;
237239
element["args"] = args;

third_party/proton/test/test_instrumentation.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,63 @@ def add_kernel(
701701
assert len(warp1_events) == 2
702702

703703

704+
def test_event_args(tmp_path: pathlib.Path):
705+
706+
@triton.jit
707+
def add_kernel(
708+
x_ptr,
709+
y_ptr,
710+
output_ptr,
711+
n_elements,
712+
BLOCK_SIZE: tl.constexpr,
713+
):
714+
with pl.scope("kernel"):
715+
pid = tl.program_id(axis=0)
716+
block_start = pid * BLOCK_SIZE
717+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
718+
mask = offsets < n_elements
719+
x = tl.load(x_ptr + offsets, mask=mask)
720+
y = tl.load(y_ptr + offsets, mask=mask)
721+
output = x + y
722+
tl.store(output_ptr + offsets, output, mask=mask)
723+
724+
size = 256
725+
x = torch.rand(size, device="cuda")
726+
y = torch.rand(size, device="cuda")
727+
temp_file = tmp_path / "test_block_metadata.chrome_trace"
728+
output = torch.empty_like(x)
729+
n_elements = output.numel()
730+
grid = (1, 1, 1)
731+
proton.start(str(temp_file.with_suffix("")), backend="instrumentation", data="trace")
732+
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024, num_warps=2)
733+
proton.finalize()
734+
735+
with open(temp_file, "rb") as f:
736+
data = json.load(f)
737+
events = data["traceEvents"]
738+
739+
# Verify we have events
740+
assert len(events) > 0
741+
742+
# Verify each event has the required metadata in args
743+
for event in events:
744+
assert "args" in event
745+
args = event["args"]
746+
747+
assert "Init Time (ns)" in args
748+
assert "Post Final Time (ns)" in args
749+
assert "Finalization Time (ns)" in args
750+
751+
# Verify timing values are reasonable
752+
init_time = args["Init Time (ns)"]
753+
post_final_time = args["Post Final Time (ns)"]
754+
finalization_time = args["Finalization Time (ns)"]
755+
756+
assert init_time >= 0
757+
assert post_final_time >= 0
758+
assert finalization_time >= 0
759+
760+
704761
def test_threaded_kernel_call(tmp_path: pathlib.Path):
705762

706763
import threading

0 commit comments

Comments
 (0)