diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index cc4209aef5..3ccc40ef55 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6169,7 +6169,6 @@ def kernel(Out): def test_globaltimer(device): - check_cuda_or_hip(device) if is_hip(): pytest.skip("test_globaltimer is flaky on AMD GPUs") @@ -6195,6 +6194,8 @@ def kernel(Out1, Out2, func: tl.constexpr): assert out2[1] - out2[0] > 0 if is_cuda(): assert h.asm["ptx"].count("%globaltimer") == 2 + elif is_xpu(): + assert h.asm["llir"].count("%tsc") == 2 else: target_arch = triton.runtime.driver.active.get_current_target().arch if "gfx11" in target_arch or "gfx12" in target_arch: @@ -6206,16 +6207,23 @@ def kernel(Out1, Out2, func: tl.constexpr): def test_smid(device): if is_hip(): pytest.skip("test_smid is not supported in HIP") - check_cuda_or_hip(device) @triton.jit - def kernel(Out): - tl.store(Out + tl.program_id(0), tl.extra.intel.smid()) + def kernel(Out, func: tl.constexpr): + tl.store(Out + tl.program_id(0), func()) + + if is_cuda(): + func = tl.extra.cuda.smid + elif is_xpu(): + func = tl.extra.intel.smid out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) - h = kernel[(out.shape[0], )](out) + h = kernel[(out.shape[0], )](out, func) assert out.sort()[0].unique().shape[0] > 0 - assert h.asm["ptx"].count("%smid") == 1 + if is_cuda(): + assert h.asm["ptx"].count("%smid") == 1 + elif is_xpu(): + assert h.asm["llir"].count("%sr0") == 1 # ----------------------- diff --git a/third_party/intel/language/intel/utils.py b/third_party/intel/language/intel/utils.py index d8eb002735..967cfdae59 100644 --- a/third_party/intel/language/intel/utils.py +++ b/third_party/intel/language/intel/utils.py @@ -3,14 +3,20 @@ @core.extern def globaltimer(_semantic=None): - return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, - _semantic=_semantic) + return core.inline_asm_elementwise( + """{\n .decl globaltimer v_type=G type=ud num_elts=2 align=qword alias=<$0, 0> \n""" + """ mov (M1_NM, 2) globaltimer(0, 0)<1> %tsc(0,0)<1;1,0> \n}""", "=rw.u", [], dtype=core.uint64, is_pure=False, + pack=1, _semantic=_semantic) @core.extern def smid(_semantic=None): - return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, - _semantic=_semantic) + sr = core.inline_asm_elementwise("mov (M1_NM, 1) $0(0, 0)<1> %sr0(0,0)<0;1,0>", "=rw.u", [], dtype=core.uint32, + is_pure=True, pack=1, _semantic=_semantic) + pos: core.constexpr = core.constexpr(9) + subslice_mask: core.constexpr = core.constexpr((1 << 11) - 1) + return sr.__and__(subslice_mask, _semantic=_semantic).__rshift__(pos, _semantic=_semantic) + @core.builtin