Skip to content

Commit 0cb53c1

Browse files
ShawnZhongdshi7
authored andcommitted
[AMD] Implement tl.extra.hip.memrealtime for timing (#7282)
Similar to `tl.extra.cuda.globaltimer`, this PR exposes `tl.extra.hip.memrealtime` for AMD GPU. This is useful for measuring the timing information for AMD kernels. Reference: https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/instruction-set-architectures/rdna2-shader-instruction-set-architecture.pdf > 7.2.3. S_MEMREALTIME > This instruction reads a 64-bit "real time-counter" and returns the > value into a pair of SGPRS: > SDST and SDST+1. The time value is from a clock for which the > frequency is constant (not affected by power modes or core clock > frequency changes). > Because the instructions can return out-of-order, the only sensible > way to use this counter is to implement S_WAITCNT 0; this imposes > a wait for all data to return from previous SMEMs before continuing.
1 parent 3317bae commit 0cb53c1

File tree

3 files changed

+35
-8
lines changed

3 files changed

+35
-8
lines changed

python/test/unit/language/test_core.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5962,24 +5962,29 @@ def kernel(Out):
59625962

59635963

59645964
def test_globaltimer(device):
5965-
if is_hip():
5966-
pytest.skip("test_globaltimer is not supported in HIP")
59675965
check_cuda_or_hip(device)
59685966

59695967
@triton.jit
5970-
def kernel(Out1, Out2):
5971-
start = tl.extra.cuda.globaltimer()
5968+
def kernel(Out1, Out2, func: tl.constexpr):
5969+
start = func()
59725970
off = tl.arange(0, 128)
59735971
for i in range(10000):
59745972
tl.store(Out1 + off, tl.load(Out1 + off) + 1)
5975-
end = tl.extra.cuda.globaltimer()
5973+
end = func()
59765974
tl.store(Out2, end - start)
59775975

59785976
out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device)
59795977
out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device)
5980-
h = kernel[(1, )](out1, out2)
5978+
if is_cuda():
5979+
func = tl.extra.cuda.globaltimer
5980+
else:
5981+
func = tl.extra.hip.memrealtime
5982+
h = kernel[(1, )](out1, out2, func)
59815983
assert out2[0] > 0
5982-
assert h.asm["ptx"].count("%globaltimer") == 2
5984+
if is_cuda():
5985+
assert h.asm["ptx"].count("%globaltimer") == 2
5986+
else:
5987+
assert h.asm["amdgcn"].count("s_memrealtime") == 2
59835988

59845989

59855990
def test_smid(device):
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from . import libdevice
22

3-
__all__ = ["libdevice"]
3+
from .utils import memrealtime
4+
5+
__all__ = ["libdevice", "memrealtime"]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from triton.language import core
2+
3+
4+
@core.extern
5+
def memrealtime(_semantic=None):
6+
"""
7+
Returns a 64-bit real time-counter value
8+
"""
9+
return core.inline_asm_elementwise(
10+
"""
11+
s_memrealtime $0
12+
s_waitcnt vmcnt(0)
13+
""",
14+
"=r",
15+
[],
16+
dtype=core.int64,
17+
is_pure=False,
18+
pack=1,
19+
_semantic=_semantic,
20+
)

0 commit comments

Comments
 (0)