
# fmt: off

-def create_ragged_descriptor(T, block_shape):
+class TensorDescriptorPtr:
+    def __init__(self, data_ptr, dtype):
+        self._data_ptr = data_ptr
+        self.dtype = dtype
+
+    def data_ptr(self):
+        return self._data_ptr
+
+
+def create_ragged_descriptor(T, block_shape, ragged_dim=0, write_only=False):
    """
    Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
    which behaves like a concatenation (along the first axis) of subarrays
@@ -18,22 +27,41 @@ def create_ragged_descriptor(T, block_shape):

    block_shape = list(block_shape)
    tensor_shape = list(T.shape)
+    rank = len(tensor_shape)
+
+    if ragged_dim < 0:
+        ragged_dim += rank

-    assert 2 <= len(tensor_shape) <= 3, "ragged tensors must have dimension 2 or 3"
-    assert len(tensor_shape) == len(block_shape), "block shape must match tensor shape"
+    assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
+
+    if write_only:
+        assert rank <= 4, "write-only ragged descriptors must have at most 4 dimensions"
+    else:
+        assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
+
+    assert len(block_shape) == rank, "block shape must have same length as tensor shape"

    max_int = 0x7fff0000
    billion = 0x40000000  # == 2**30

-    assert tensor_shape[0] <= billion, "number of rows may not exceed 2**30"
+    assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
+    tensor_shape[ragged_dim] = billion
+    ragged_stride = T.stride(ragged_dim)

    # we prepend an extra two dimensions and rely on the fact that pointers
    # have 64-bit wraparound semantics:
-    tma_stride = [2**34 - T.stride(0), T.stride(0)] + [T.stride(i) for i in range(len(tensor_shape))]
-    tma_shape = [max_int, max_int, billion] + tensor_shape[1:]
+    tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
+    tma_shape = [max_int, max_int] + tensor_shape
    box_shape = [1, 1] + block_shape
+    ptr = T.data_ptr()

-    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
+    if write_only:
+        tma_stride = tma_stride[1:]
+        tma_shape = tma_shape[1:]
+        box_shape = box_shape[1:]
+        ptr = (ptr - billion * ragged_stride * T.element_size()) % (2**64)
+
+    return TensorDescriptor(TensorDescriptorPtr(ptr, T.dtype), tma_shape, tma_stride, box_shape)


@triton.jit
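
As a rough host-side usage sketch (not part of the diff; the tensor, its shapes, and the variable names are illustrative, and it assumes a CUDA PyTorch tensor plus whatever TensorDescriptor class this module already imports), the new signature might be exercised like this:

    import torch

    # rows of several logical batches concatenated along dim 0
    T = torch.randn(4096, 256, device="cuda", dtype=torch.float16)

    # default: ragged along dim 0, read-write, rank at most 3
    desc = create_ragged_descriptor(T, [64, 256])

    # write-only variant: drops one phantom dimension, so rank up to 4 is allowed
    wdesc = create_ragged_descriptor(T, [64, 256], write_only=True)
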
@@ -50,7 +78,7 @@ def to_ragged_indices(batch_offset, batch_size, row):


@triton.jit
-def load_ragged(TMA, batch_offset, batch_size, coords):
+def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
    """
    Read from a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where reading outside the subarray gives zeros.
@@ -59,14 +87,16 @@ def load_ragged(TMA, batch_offset, batch_size, coords):
    TMA.load().
    """

-    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[0])
-    data = TMA.load([c0, c1, c2] + coords[1:])
+    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
    data = tl.reshape(data, data.shape[2:])
    return data


@triton.jit
-def store_ragged(TMA, batch_offset, batch_size, coords, data):
+def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
    """
    Write to a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where writes outside the subarray are masked
@@ -76,6 +106,18 @@ def store_ragged(TMA, batch_offset, batch_size, coords, data):
    TMA.store().
    """

-    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[0])
-    data = tl.reshape(data, [1, 1] + data.shape)
-    TMA.store([c0, c1, c2] + coords[1:], data)
+    if len(TMA.shape) == len(coords) + 1:
+        write_only: tl.constexpr = True
+    elif len(TMA.shape) == len(coords) + 2:
+        write_only: tl.constexpr = False
+    else:
+        tl.static_assert(False, "TMA must be a ragged descriptor")
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+
+    if write_only:
+        data = tl.reshape(data, [1] + data.shape)
+        TMA.store([c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
+    else:
+        data = tl.reshape(data, [1, 1] + data.shape)
+        TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
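
For orientation, a minimal kernel sketch of how the updated helpers might be called; the kernel name, its parameters, and the assumption that coords indexes within the batch's subarray are illustrative rather than taken from this patch:

    @triton.jit
    def _copy_batch(TMA_in, TMA_out, batch_offset, batch_size, BLOCK_M: tl.constexpr):
        # one program handles BLOCK_M rows of a single batch
        row = tl.program_id(0) * BLOCK_M
        # rows past batch_size read back as zeros thanks to the descriptor's bounds-checking
        data = load_ragged(TMA_in, batch_offset, batch_size, [row, 0])
        # writes past batch_size are masked off; store_ragged accepts either descriptor
        # flavor because it inspects len(TMA.shape)
        store_ragged(TMA_out, batch_offset, batch_size, [row, 0], data)

On the host, TMA_in and TMA_out would be built with create_ragged_descriptor using a block_shape whose leading entry matches BLOCK_M; TMA_in must be a read-write descriptor (load_ragged asserts this), while TMA_out may also be write-only.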