@@ -24,7 +24,7 @@
 # ===-----------------------------------------------------------------------===#
 
 
-@tl.constexpr_function
+@gluon.constexpr_function
 def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
     assert len(shape) == 2, "expected a 2D tensor"
     assert num_warps in [4, 8], "expected 4 or 8 warps"
@@ -60,15 +60,15 @@ def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
     )
 
 
-@tl.constexpr_function
+@gluon.constexpr_function
 def get_mma_instr_shape(shape, element_ty):
     m = 128 if shape[0] >= 128 else 64
     n = 256 if shape[1] >= 256 else shape[1]
     k = 256 // element_ty.primitive_bitwidth
     return (m, n, k)
 
 
-@tl.constexpr_function
+@gluon.constexpr_function
 def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
     packing_factor = 2 if fp4_padded else 1
 
@@ -100,7 +100,7 @@ def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
     )
 
 
-@tl.constexpr_function
+@gluon.constexpr_function
 def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
     instr_shape = get_mma_instr_shape(shape, dtype)
     return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
@@ -111,7 +111,7 @@ def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
 # ===-----------------------------------------------------------------------===#
 
 
-@tl.constexpr_function
+@gluon.constexpr_function
 def get_load_size_bytes(desc):
     size = desc.dtype.primitive_bitwidth // 8
     for dim in desc.block_type.shape:
@@ -385,7 +385,7 @@ def __init__(self, channel, instr_shape, shape):
     def release(self):
         self.channel.release()
 
-    @tl.constexpr_function
+    @gluon.constexpr_function
     def get_reg_layout(self, num_warps):
         return get_tmem_32x32b_reg_layout(self.instr_shape, self.shape, num_warps)
 
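For reference, a minimal hand-worked sketch of two of the renamed helpers, using hypothetical inputs; the loop body of get_load_size_bytes is cut off by the hunk above, so the multiply step is an assumption:

# Hand-worked sketch (hypothetical inputs, not part of the commit):

# get_mma_instr_shape((128, 128), gl.float16)  # float16 primitive_bitwidth = 16
#   m = 128             # shape[0] >= 128
#   n = 128             # shape[1] < 256, so n = shape[1]
#   k = 256 // 16 = 16
#   -> instr_shape = (128, 128, 16)

# get_load_size_bytes(desc) for a float16 descriptor with block shape [128, 64],
# assuming the truncated loop multiplies size by each block dimension:
#   size = 16 // 8 = 2           # bytes per element
#   size = 2 * 128 * 64 = 16384  # bytes per load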