Skip to content

Commit b495645

Browse files
authored
[AMD] Implement tl.extra.hip.memrealtime for timing (#7282)
Similar to `tl.extra.cuda.globaltimer`, this PR exposes `tl.extra.hip.memrealtime` for AMD GPUs. This is useful for measuring timing information in AMD kernels.

Reference: https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna2-shader-instruction-set-architecture.pdf

> 7.2.3. S_MEMREALTIME
> This instruction reads a 64-bit "real time-counter" and returns the value into a pair of SGPRs: SDST and SDST+1. The time value is from a clock for which the frequency is constant (not affected by power modes or core clock frequency changes).
> Because the instructions can return out-of-order, the only sensible way to use this counter is to implement S_WAITCNT 0; this imposes a wait for all data to return from previous SMEMs before continuing.
1 parent 1607e09 commit b495645

File tree

3 files changed

+35
-8
lines changed

3 files changed

+35
-8
lines changed

python/test/unit/language/test_core.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5930,24 +5930,29 @@ def kernel(Out):
59305930

59315931

59325932
def test_globaltimer(device):
    """Check that the backend timer intrinsic is emitted and advances.

    A single-program kernel reads the timer, performs 10000 load/store
    round-trips on a 128-element buffer, reads the timer again and stores the
    difference. On CUDA the timer is ``tl.extra.cuda.globaltimer``; otherwise
    ``tl.extra.hip.memrealtime`` is used. The test asserts that the measured
    difference is positive and that the backend assembly contains exactly two
    timer reads.
    """
    check_cuda_or_hip(device)

    @triton.jit
    def kernel(Out1, Out2, func: tl.constexpr):
        # `func` is the timer intrinsic, passed in as a compile-time constant.
        start = func()
        offsets = tl.arange(0, 128)
        for i in range(10000):
            tl.store(Out1 + offsets, tl.load(Out1 + offsets) + 1)
        end = func()
        tl.store(Out2, end - start)

    out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device)
    out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device)

    # Resolve the backend once; it selects the intrinsic as well as which
    # assembly listing / mnemonic is inspected afterwards.
    on_cuda = is_cuda()
    func = tl.extra.cuda.globaltimer if on_cuda else tl.extra.hip.memrealtime
    asm_kind, mnemonic = ("ptx", "%globaltimer") if on_cuda else ("amdgcn", "s_memrealtime")

    h = kernel[(1, )](out1, out2, func)

    assert out2[0] > 0
    assert h.asm[asm_kind].count(mnemonic) == 2
59515956

59525957

59535958
def test_smid(device):
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from . import libdevice
22

3-
__all__ = ["libdevice"]
3+
from .utils import memrealtime
4+
5+
__all__ = ["libdevice", "memrealtime"]

third_party/amd/language/hip/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from triton.language import core
2+
3+
4+
@core.extern
def memrealtime(_semantic=None):
    """
    Returns a 64-bit real time-counter value.

    Emits ``s_memrealtime`` through inline assembly and returns the result as
    an ``int64``. The op is marked impure (``is_pure=False``) so that separate
    calls are not merged or removed.

    :param _semantic: semantic/builder object supplied by the Triton frontend;
        forwarded unchanged to ``inline_asm_elementwise``.
    """
    # NOTE(review): per the AMD ISA description of S_MEMREALTIME, the value is
    # returned like scalar-memory (SMEM) data and may arrive out of order, so
    # the counter is only safe to read after waiting for outstanding SMEM
    # returns. Those are tracked by lgkmcnt, not vmcnt; the original vmcnt(0)
    # wait is kept so previously issued vector-memory traffic is still drained.
    return core.inline_asm_elementwise(
        """
        s_memrealtime $0
        s_waitcnt vmcnt(0) lgkmcnt(0)
        """,
        "=r",
        [],
        dtype=core.int64,
        is_pure=False,
        pack=1,
        _semantic=_semantic,
    )

0 commit comments

Comments
 (0)