
Commit 6af74b2

[FRONTEND] Support passing dtype as constexpr for tma load (#4821)
Fixes a compile error like the one below when a dtype is passed through a kernel argument to `tl._experimental_descriptor_load`:

    AttributeError: 'constexpr' object has no attribute 'to_ir'
1 parent: fe47f98
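For context, `tl.constexpr` is the wrapper in which compile-time kernel arguments arrive, and it does not forward attribute lookups such as `to_ir` to the wrapped value. A minimal sketch of the failure mode and of the unwrapping behavior the fix relies on (assumes a Triton build containing this commit; `_constexpr_to_value` is an internal helper, not public API):

import triton.language as tl
from triton.language.core import _constexpr_to_value  # internal helper used by the fix

# A dtype forwarded through a `dtype: tl.constexpr` kernel parameter arrives
# wrapped in the constexpr class, which hides the dtype's methods:
wrapped = tl.constexpr(tl.float16)
assert not hasattr(wrapped, "to_ir")  # the source of the AttributeError

# _constexpr_to_value unwraps constexpr objects and passes other values through:
assert _constexpr_to_value(wrapped) is tl.float16
assert _constexpr_to_value(tl.float16) is tl.float16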

2 files changed (+6, -6)

python/test/unit/hopper/test_experimental_tma.py (5 additions, 5 deletions)

@@ -57,7 +57,7 @@ def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr):
 @triton.jit
 def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
                       M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-                      BYVAL_TMA: tl.constexpr):
+                      BYVAL_TMA: tl.constexpr, dtype: tl.constexpr):
     if not BYVAL_TMA:
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
@@ -72,11 +72,11 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     offs_k = 0
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], tl.float16)
-        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16)
+        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
+        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], dtype)
         accumulator = tl.dot(a, b, acc=accumulator)
         offs_k += BLOCK_SIZE_K
-    accumulator = accumulator.to(tl.float16)
+    accumulator = accumulator.to(dtype)
     tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])


@@ -101,7 +101,7 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm
     desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
     kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
                                 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
-                                    num_warps=8, num_stages=num_stages)
+                                    num_warps=8, num_stages=num_stages, dtype=tl.float16)
     ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
     torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
     if BLOCK_M >= 64 and BLOCK_N >= 64:
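The launch site above passes dtype=tl.float16 as a keyword argument; inside the kernel it arrives as a tl.constexpr. The same pattern in a self-contained form (cast_kernel is a hypothetical example for illustration, not part of this patch; assumes a CUDA GPU with Triton and PyTorch installed):

import torch
import triton
import triton.language as tl

@triton.jit
def cast_kernel(x_ptr, y_ptr, n, out_dtype: tl.constexpr, BLOCK: tl.constexpr):
    # out_dtype arrives wrapped in tl.constexpr; tensor.to() unwraps it,
    # just as _experimental_descriptor_load does after this patch.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, x.to(out_dtype), mask=mask)

x = torch.randn(4096, device="cuda", dtype=torch.float32)
y = torch.empty_like(x, dtype=torch.float16)
cast_kernel[(triton.cdiv(x.numel(), 1024), )](x, y, x.numel(), out_dtype=tl.float16, BLOCK=1024)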

python/triton/language/core.py (1 addition, 1 deletion)

@@ -1613,7 +1613,7 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=

     This loads a tensor of data based on the descriptor and offsets.
     """
-    type = block_type(dtype, shape)
+    type = block_type(_constexpr_to_value(dtype), shape)
     return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder)

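For reference, the helper added to this call path simply strips the constexpr wrapper. Paraphrased, it behaves roughly like this (a sketch of its behavior, not the verbatim source):

from triton.language.core import constexpr

def _constexpr_to_value(v):
    # Return the wrapped Python value for tl.constexpr inputs; plain values,
    # including tl.dtype objects like tl.float16, pass through unchanged.
    if isinstance(v, constexpr):
        return v.value
    return v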
