@@ -1,6 +1,5 @@
 import torch
 import triton
-import triton.language as tl
 import pytest
 import itertools
 
@@ -25,7 +24,7 @@
 # ===-----------------------------------------------------------------------===#
 
 
-@tl.constexpr_function
+@gl.constexpr_function
 def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
     assert len(shape) == 2, "expected a 2D tensor"
     assert num_warps in [4, 8], "expected 4 or 8 warps"
@@ -61,15 +60,15 @@ def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
     )
 
 
-@tl.constexpr_function
+@gl.constexpr_function
 def get_mma_instr_shape(shape, element_ty):
     m = 128 if shape[0] >= 128 else 64
     n = 256 if shape[1] >= 256 else shape[1]
     k = 256 // element_ty.primitive_bitwidth
     return (m, n, k)
 
 
-@tl.constexpr_function
+@gl.constexpr_function
 def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
     packing_factor = 2 if fp4_padded else 1
 
@@ -99,7 +98,7 @@ def get_nvmma_layout(shape, element_ty, order=[1, 0], fp4_padded=False):
     )
 
 
-@tl.constexpr_function
+@gl.constexpr_function
 def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
     instr_shape = get_mma_instr_shape(shape, dtype)
     return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
@@ -133,7 +132,7 @@ def alloc(shape: gl.constexpr, dtype: gl.constexpr, layout: gl.constexpr, num_bu
     mem = alloc_fn(dtype, [num_buffers] + shape, layout)
     ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
     empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
-    for i in tl.static_range(num_buffers):
+    for i in gl.static_range(num_buffers):
         mbarrier.init(ready_bars.index(i), count=1)
         mbarrier.init(empty_bars.index(i), count=num_consumers)
         mbarrier.arrive(empty_bars.index(i), count=num_consumers)
@@ -179,7 +178,7 @@ def create_consumer(self):
     def release(self):
         if isinstance(self.mem, gl.shared_memory_descriptor):
             self.mem._keep_alive()
-        for i in tl.static_range(self.num_buffers):
+        for i in gl.static_range(self.num_buffers):
             mbarrier.invalidate(self.ready_bars.index(i))
             mbarrier.invalidate(self.empty_bars.index(i))
 
@@ -847,7 +846,7 @@ def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
     mbarrier.arrive(corr_bar, count=1)
     alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)
 
-    for i in tl.static_range(config.SPLIT_D_FACTOR):
+    for i in gl.static_range(config.SPLIT_D_FACTOR):
         o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
         o = o_ref.load(config.o_splitn_layout)
         o = _mul_f32x2(o, alpha[:, None])
@@ -882,7 +881,7 @@ def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_pr
     SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR
 
     scale = 1 / l_i
-    for i in tl.static_range(SPLIT_N_FACTOR):
+    for i in gl.static_range(SPLIT_N_FACTOR):
         o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
         o = o_ref.load(config.o_splitn_layout)
         o = _mul_f32x2(o, scale[:, None])
@@ -992,7 +991,7 @@ def attention_kernel( #
 def torch_dtype_to_triton(dtype):
     if dtype == torch.float8_e5m2:
         return gl.float8e5
-    return getattr(tl, str(dtype).split('.')[1])
+    return getattr(gl, str(dtype).split('.')[1])
 
 
 def make_tensor_desc(x, shape, strides, block_shape):
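Every hunk above makes the same mechanical change: the compile-time helpers move from the triton.language namespace (tl) to Gluon's language namespace (gl), and the now-unused `import triton.language as tl` is dropped. The gl and gluon bindings themselves sit outside the hunks shown, so the sketch below of the assumed imports is illustrative rather than a quote from the file, and `cdiv` is a hypothetical helper added only to show the decorator in use.

# Assumed imports (not visible in the hunks above) that bind the names
# this diff relies on.
from triton.experimental import gluon
from triton.experimental.gluon import language as gl


# gl.constexpr_function marks a plain Python helper that is evaluated at
# compile time, a drop-in for the tl.constexpr_function it replaces here.
@gl.constexpr_function
def cdiv(a, b):  # hypothetical example helper
    return (a + b - 1) // b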
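As a sanity check on one of the renamed helpers, get_mma_instr_shape can be re-traced in plain Python. The sketch below is a reference re-implementation for illustration only: the real helper takes a Gluon element type and reads its primitive_bitwidth, which is replaced here by an explicit bit width so the example runs standalone.

def get_mma_instr_shape_ref(shape, bitwidth):
    # Plain-Python mirror of the @gl.constexpr_function in the diff above.
    m = 128 if shape[0] >= 128 else 64
    n = 256 if shape[1] >= 256 else shape[1]
    k = 256 // bitwidth  # one MMA instruction consumes 256 bits along K
    return (m, n, k)


# fp16 (16 bits) on a 128x64 block keeps N and uses a K tile of 16:
assert get_mma_instr_shape_ref((128, 64), 16) == (128, 64, 16)
# fp8 (8 bits) doubles K; blocks under 128 rows fall back to m=64:
assert get_mma_instr_shape_ref((64, 256), 8) == (64, 256, 32)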