Skip to content

Commit 89823c2

Browse files
htyumeta-codesync[bot]
authored and committed
[TLX] Add tlx.size_of (#710)
Summary: Introduces a new tlx.size_of() utility function that returns the size in bytes of a given Triton dtype. This helps unify kernels across different dtypes. Pull Request resolved: #710 Reviewed By: dshi7 Differential Revision: D88204044 Pulled By: htyu fbshipit-source-id: 13b9ba652619956808989aca75ab1d48fa78fb53
1 parent d17a5ab commit 89823c2

File tree

6 files changed

+85
-11
lines changed

6 files changed

+85
-11
lines changed

README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,24 @@ Examples: how mbarriers are communicated in warp specialization
221221

222222
Returns the id of the current thread instance along the given `axis`.
223223

224+
- `tlx.dtype_of(v)`
225+
226+
Returns the dtype of a tensor or tensor descriptor.
227+
228+
- `tlx.size_of(dtype)`
229+
230+
Returns the size in bytes of a given Triton dtype. This is useful for dynamically computing memory sizes based on dtype, especially in barrier synchronization code.
231+
232+
Example:
233+
```python
234+
# Instead of hardcoding size values
235+
tlx.barrier_expect_bytes(barrier, 2 * BLOCK_M * BLOCK_K) # Assumes float16
236+
237+
# Use size_of for dtype-aware computation
238+
tlx.barrier_expect_bytes(barrier,
239+
tlx.size_of(tlx.dtype_of(desc)) * BLOCK_M * BLOCK_K)
240+
```
241+
224242
- `tlx.clock64()`
225243

226244
Returns the current 64-bit hardware clock value. E.g,

python/test/unit/language/test_tlx.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1977,6 +1977,34 @@ def kernel(y_ptr, BLOCK_SIZE: tl.constexpr):
19771977
assert kerenl_info.asm["ttir"].count("store") == 1
19781978

19791979

1980+
def test_size_of(device):
    """tlx.size_of should return the size in bytes of each supported dtype.

    The kernel stores the constexpr sizes into an output buffer, which is
    then compared on the host against the known byte widths.
    """

    @triton.jit
    def size_of_kernel(output_ptr):
        # Query size_of for a representative set of dtypes.
        size_fp32 = tlx.size_of(tl.float32)
        size_fp16 = tlx.size_of(tl.float16)
        size_int32 = tlx.size_of(tl.int32)
        size_int8 = tlx.size_of(tl.int8)
        size_int64 = tlx.size_of(tl.int64)

        # Store results so the host can verify them.
        tl.store(output_ptr + 0, size_fp32)
        tl.store(output_ptr + 1, size_fp16)
        tl.store(output_ptr + 2, size_int32)
        tl.store(output_ptr + 3, size_int8)
        tl.store(output_ptr + 4, size_int64)

    # Expected sizes in bytes for float32, float16, int32, int8, int64.
    expected_sizes = torch.tensor([4, 2, 4, 1, 8], dtype=torch.int32, device=device)
    output = torch.zeros(5, dtype=torch.int32, device=device)

    # A single program instance suffices: the kernel writes five scalars.
    size_of_kernel[(1, )](output)

    torch.testing.assert_close(output, expected_sizes)
2006+
2007+
19802008
@pytest.mark.skipif(not is_blackwell(), reason="Need Blackwell")
19812009
def test_async_dots_blackwell_tmem(device):
19822010
"""

third_party/tlx/language/tlx/__init__.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,26 @@
1515
CLCPipelineContext,
1616
async_token,
1717
)
18-
from .mem_ops import (local_alloc, local_view, remote_view, local_slice, subslice, async_load, async_load_commit_group,
19-
async_load_wait_group, local_load, local_store, local_trans, local_reinterpret, global_alloc,
20-
async_descriptor_load, async_descriptor_store, async_descriptor_store_wait, fence_async_shared,
21-
make_tensor_descriptor)
18+
from .mem_ops import (
19+
local_alloc,
20+
local_view,
21+
remote_view,
22+
local_slice,
23+
subslice,
24+
async_load,
25+
async_load_commit_group,
26+
async_load_wait_group,
27+
local_load,
28+
local_store,
29+
local_trans,
30+
local_reinterpret,
31+
global_alloc,
32+
async_descriptor_load,
33+
async_descriptor_store,
34+
async_descriptor_store_wait,
35+
fence_async_shared,
36+
make_tensor_descriptor,
37+
)
2238
from .barrier import (
2339
alloc_barriers,
2440
barrier_expect_bytes,
@@ -38,6 +54,7 @@
3854
thread_id,
3955
async_task_replica_id,
4056
dtype_of,
57+
size_of,
4158
clock64,
4259
stoch_round,
4360
)
@@ -107,6 +124,7 @@
107124
"thread_id",
108125
"async_task_replica_id",
109126
"dtype_of",
127+
"size_of",
110128
"clock64",
111129
"stoch_round",
112130
# dynamic launcher ops

third_party/tlx/language/tlx/utility.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def is_hip():
    """Return True when the active compilation target uses the HIP backend."""
    return driver.active.get_current_target().backend == "hip"
1010

1111

1212
def cuda_parse_arch(arch):
@@ -42,8 +42,9 @@ def thread_id(axis, _semantic=None):
4242
@tl.builtin
def async_task_replica_id(_semantic=None):
    """Return the replica id of the innermost enclosing async region.

    Reads the top of the compiler's region_replica_id_stack; must therefore
    be called from inside an async region (non-empty stack).
    """
    from triton.language.extra.tlx.compiler.code_generator import region_replica_id_stack

    assert region_replica_id_stack, (
        "async_task_replica_id must be called inside an async region where the stack must be non-empty")
    return tl.constexpr(region_replica_id_stack[-1])
4849

4950

@@ -63,6 +64,15 @@ def dtype_of(v, _semantic=None) -> tl.dtype:
6364
raise ValueError(f"dtype_of only works on tensors and tensor descriptors, but got {v}")
6465

6566

67+
@tl.builtin
def size_of(dtype: tl.dtype, _semantic=None) -> tl.constexpr:
    """
    Returns the size in bytes of a given Triton dtype.

    Useful for dtype-aware size computations, e.g. the byte count expected by
    a barrier: ``tlx.size_of(tlx.dtype_of(desc)) * BLOCK_M * BLOCK_K``.
    """
    assert isinstance(dtype, tl.dtype), f"size_of expects a dtype, but got {type(dtype)}"
    # Guard against sub-byte dtypes, for which integer division by 8 would
    # silently floor the result to zero bytes.
    assert dtype.primitive_bitwidth % 8 == 0, (
        f"size_of expects a dtype with a whole-byte width, but got {dtype.primitive_bitwidth} bits")
    return tl.constexpr(dtype.primitive_bitwidth // 8)
74+
75+
6676
@tl.builtin
6777
def clock64(_semantic=None):
6878
"""

third_party/tlx/tutorials/blackwell-grouped-gemm_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def grouped_matmul_tlx_kernel(
542542
buf, phase = _get_bufidx_phase(accum_cnt, NUM_SMEM_BUFFERS)
543543
tlx.barrier_wait(smem_empty_bars[buf], phase ^ 1)
544544
tlx.barrier_expect_bytes(smem_full_bars[buf],
545-
2 * (BLOCK_SIZE_M + BLOCK_SIZE_N) * BLOCK_SIZE_K) # float16
545+
tlx.size_of(dtype) * (BLOCK_SIZE_M + BLOCK_SIZE_N) * BLOCK_SIZE_K)
546546
tlx.async_descriptor_load(a_desc, buffers_A[buf], [offs_am, kk * BLOCK_SIZE_K],
547547
smem_full_bars[buf])
548548
tlx.async_descriptor_load(b_desc, buffers_B[buf], [kk * BLOCK_SIZE_K, offs_bn],

third_party/tlx/tutorials/hopper-gemm-ws_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,23 +119,23 @@ def matmul_kernel_tlx_ws(a_desc, b_desc, c_desc, #
119119
empty_a_1st = tlx.local_view(bars_empty_a, buf) # mbar
120120
full_a_1st = tlx.local_view(bars_full_a, buf) # mbar
121121
tlx.barrier_wait(bar=empty_a_1st, phase=p) # EmptyBar A1 wait
122-
tlx.barrier_expect_bytes(full_a_1st, BLOCK_M_SPLIT * BK * 2)
122+
tlx.barrier_expect_bytes(full_a_1st, BLOCK_M_SPLIT * BK * tlx.size_of(tlx.dtype_of(a_desc)))
123123
data_a_1st = tlx.local_view(a, buf) # smem data
124124
tlx.async_descriptor_load(a_desc, data_a_1st, [offset_am, offset_k], full_a_1st)
125125

126126
# Async load to b[buf]
127127
empty_b = tlx.local_view(bars_empty_b, buf)
128128
full_b = tlx.local_view(bars_full_b, buf)
129129
tlx.barrier_wait(bar=empty_b, phase=p)
130-
tlx.barrier_expect_bytes(full_b, BN * BK * 2)
130+
tlx.barrier_expect_bytes(full_b, BN * BK * tlx.size_of(tlx.dtype_of(a_desc)))
131131
data_b = tlx.local_view(b, buf)
132132
tlx.async_descriptor_load(b_desc, data_b, [offset_k, offset_bn], full_b)
133133

134134
# Async load to a[buf+NUM_STAGES]
135135
empty_a_2nd = tlx.local_view(bars_empty_a, buf + NUM_STAGES)
136136
full_a_2nd = tlx.local_view(bars_full_a, buf + NUM_STAGES)
137137
tlx.barrier_wait(bar=empty_a_2nd, phase=p)
138-
tlx.barrier_expect_bytes(bar=full_a_2nd, size=BLOCK_M_SPLIT * BK * 2)
138+
tlx.barrier_expect_bytes(bar=full_a_2nd, size=BLOCK_M_SPLIT * BK * tlx.size_of(tlx.dtype_of(a_desc)))
139139
data_a_2nd = tlx.local_view(a, buf + NUM_STAGES) # smem data
140140
tlx.async_descriptor_load(a_desc, data_a_2nd, [offset_am + BLOCK_M_SPLIT, offset_k], full_a_2nd)
141141

0 commit comments

Comments
 (0)