Skip to content

Commit 94afdc9

Browse files
authored
[Benchmarking] Update FA/Gluon to a new OSS version
Differential Revision: D81245001 Pull Request resolved: #370
1 parent 53ba426 commit 94afdc9

File tree

2 files changed

+430
-237
lines changed

2 files changed

+430
-237
lines changed

tritonbench/kernels/gluon_attention_forward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
# ===-----------------------------------------------------------------------===#
2525

2626

27-
@tl.constexpr_function
27+
@gluon.constexpr_function
2828
def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
2929
assert len(shape) == 2, "expected a 2D tensor"
3030
assert num_warps in [4, 8], "expected 4 or 8 warps"
@@ -60,15 +60,15 @@ def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
6060
)
6161

6262

63-
@tl.constexpr_function
63+
@gluon.constexpr_function
6464
def get_mma_instr_shape(shape, element_ty):
6565
m = 128 if shape[0] >= 128 else 64
6666
n = 256 if shape[1] >= 256 else shape[1]
6767
k = 256 // element_ty.primitive_bitwidth
6868
return (m, n, k)
6969

7070

71-
@tl.constexpr_function
71+
@gluon.constexpr_function
7272
def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
7373
packing_factor = 2 if fp4_padded else 1
7474

@@ -100,7 +100,7 @@ def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
100100
)
101101

102102

103-
@tl.constexpr_function
103+
@gluon.constexpr_function
104104
def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
105105
instr_shape = get_mma_instr_shape(shape, dtype)
106106
return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
@@ -111,7 +111,7 @@ def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
111111
# ===-----------------------------------------------------------------------===#
112112

113113

114-
@tl.constexpr_function
114+
@gluon.constexpr_function
115115
def get_load_size_bytes(desc):
116116
size = desc.dtype.primitive_bitwidth // 8
117117
for dim in desc.block_type.shape:
@@ -385,7 +385,7 @@ def __init__(self, channel, instr_shape, shape):
385385
def release(self):
386386
self.channel.release()
387387

388-
@tl.constexpr_function
388+
@gluon.constexpr_function
389389
def get_reg_layout(self, num_warps):
390390
return get_tmem_32x32b_reg_layout(self.instr_shape, self.shape, num_warps)
391391

0 commit comments

Comments
 (0)