Skip to content

Commit 8e52b2e

Browse files
authored
[Gluon] Add numel and nbytes properties (#7507)
This adds three new properties: `shared_memory_descriptor.numel`, `block_type.numel`, and `block_type.nbytes`. It also updates the attention tutorial to use `tensor_descriptor.block_type.nbytes` when calling `mbarrier.expect`.
1 parent 345c633 commit 8e52b2e

File tree

4 files changed

+16
-15
lines changed

4 files changed

+16
-15
lines changed

python/test/gluon/test_frontend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_
113113
layout_b: ttgl.constexpr, smem_layout: ttgl.constexpr):
114114
unused = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK, YBLOCK], smem_layout)
115115
a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout_a)
116+
tl.static_assert(a.numel == unused.numel)
117+
tl.static_assert(unused.numel == XBLOCK * YBLOCK)
116118
mem = ttgl.allocate_shared_memory(ttgl.int32, a.shape, smem_layout, a)
117119
b = mem.load(layout_b) # noqa: F841
118120
mem.store(a)
@@ -611,7 +613,8 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
611613
mbarrier.init(bar, count=1)
612614

613615
tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
614-
mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float16.primitive_bitwidth // 8)
616+
tl.static_assert(input_desc.block_type.nbytes == XBLOCK * XBLOCK * 2)
617+
mbarrier.expect(bar, input_desc.block_type.nbytes)
615618
mbarrier.wait(bar, 0)
616619

617620
mbarrier.invalidate(bar)

python/triton/experimental/gluon/language/_core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
import math
23
from typing import TypeVar, List, TYPE_CHECKING, Tuple
34
from functools import wraps
45

@@ -216,6 +217,10 @@ def shape(self):
216217
def rank(self):
217218
return len(self.shape)
218219

220+
@property
221+
def numel(self) -> int:
222+
return math.prod(self.shape)
223+
219224
@property
220225
def layout(self):
221226
return self.type.layout

python/triton/language/core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import math
34
from warnings import warn
45
from contextlib import contextmanager
56
from enum import Enum
@@ -752,6 +753,10 @@ def __eq__(self, other) -> bool:
752753
def scalar(self):
753754
return self.element_ty
754755

756+
@property
757+
def nbytes(self):
758+
return self.numel * (self.element_ty.primitive_bitwidth // 8)
759+
755760
def mangle(self) -> str:
756761
elt = self.scalar.mangle()
757762
shape = '_'.join(map(str, self.shape))
@@ -878,10 +883,7 @@ def __init__(self, handle, type: dtype):
878883
self.handle = handle
879884
# Block shape
880885
self.shape = type.shape if type.is_block() else ()
881-
self.numel = 1
882-
for s in self.shape:
883-
self.numel *= s
884-
self.numel = constexpr(self.numel)
886+
self.numel = constexpr(math.prod(self.shape))
885887
self.type = type # Tensor type (can be block_type)
886888
# Following the practice in pytorch, dtype is scalar type
887889
self.dtype = type.scalar

python/tutorials/gluon/01-attention-forward.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -233,18 +233,9 @@ def get_desc_channel(desc, num_buffers: gl.constexpr, num_consumers: gl.constexp
233233
return SharedMemoryChannel.alloc(shape, desc.dtype, layout, num_buffers, num_consumers)
234234

235235

236-
@tl.constexpr_function
237-
def get_load_size_bytes(desc):
238-
size = desc.dtype.primitive_bitwidth // 8
239-
for dim in desc.block_type.shape:
240-
size *= dim
241-
return size
242-
243-
244236
@gluon.jit
245237
def issue_async_tma_load(smem, bar, desc, offset):
246-
size: gl.constexpr = get_load_size_bytes(desc)
247-
mbarrier.expect(bar, size)
238+
mbarrier.expect(bar, desc.block_type.nbytes)
248239
tma.async_copy_global_to_shared(desc, [offset, 0], bar, smem)
249240

250241

0 commit comments

Comments
 (0)