
Commit 3320f47

Merge branch 'main' into tkuczynski/enable_test_small_batch_matmul
2 parents: 12510aa + ceefb53

File tree

22 files changed: +1173 -113 lines changed


.github/workflows/build-test-reusable.yml

Lines changed: 6 additions & 0 deletions
@@ -196,6 +196,7 @@ jobs:
         suite:
           - minicore
           - scaled_dot
+          - gluon
           - rest
           - tutorial-fa-64
           - tutorial-fa-128-fwdfp8
@@ -306,6 +307,11 @@ jobs:
         run: |
           ${{ env.TRITON_TEST_CMD }} --scaled-dot

+      - name: Run gluon tests
+        if: matrix.suite == 'gluon' && inputs.driver_version == 'rolling'
+        run: |
+          ${{ env.TRITON_TEST_CMD }} --gluon
+
       - name: Run interpreter tests
         if: matrix.suite == 'rest'
         run: |

.github/workflows/build-test-windows.yml

Lines changed: 7 additions & 0 deletions
@@ -148,6 +148,13 @@ jobs:
           cd ${{ env.NEW_WORKSPACE }}
           ${{ env.TRITON_TEST_CMD }} --core

+      - name: Run gluon tests
+        run: |
+          .venv\Scripts\activate.ps1
+          Invoke-BatchFile "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
+          cd ${{ env.NEW_WORKSPACE }}
+          ${{ env.TRITON_TEST_CMD }} --gluon
+
       - name: Run triton kernels tests
         run: |
           .venv\Scripts\activate.ps1

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 15 additions & 17 deletions
@@ -213,17 +213,18 @@ def _attn_bwd_dkdv(dk, dv, #
                    # Filled in by the wrapper.
                    start_n, start_m, num_steps, #
                    MASK: tl.constexpr):
-    offs_m = start_m + tl.arange(0, BLOCK_M1)
     offs_n = start_n + tl.arange(0, BLOCK_N1)
-    offs_k = tl.arange(0, HEAD_DIM)
-    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
-    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
+    qT_desc = tl.make_tensor_descriptor(Q, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_M1])
+
+    do_desc = tl.make_tensor_descriptor(DO, shape=[N_CTX, HEAD_DIM], strides=[stride_tok, stride_d],
+                                        block_shape=[BLOCK_M1, HEAD_DIM])
     # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
     curr_m = start_m
     step_m = BLOCK_M1
     for blk_idx in range(num_steps):
-        qT = tl.load(qT_ptrs)
+        qT = qT_desc.load([0, start_m + blk_idx * step_m])
         # Load m before computing qk to reduce pipeline stall.
         offs_m = curr_m + tl.arange(0, BLOCK_M1)
         m = tl.load(M + offs_m)
@@ -233,7 +234,7 @@ def _attn_bwd_dkdv(dk, dv, #
         if MASK:
             mask = (offs_m[None, :] >= offs_n[:, None])
             pT = tl.where(mask, pT, 0.0)
-        do = tl.load(do_ptrs)
+        do = do_desc.load([start_m + blk_idx * step_m, 0])
         # Compute dV.
         ppT = pT
         ppT = ppT.to(tl.float16)
@@ -247,8 +248,6 @@ def _attn_bwd_dkdv(dk, dv, #
         dk += tl.dot(dsT, tl.trans(qT))
         # Increment pointers.
         curr_m += step_m
-        qT_ptrs += step_m * stride_tok
-        do_ptrs += step_m * stride_tok
     return dk, dv


@@ -267,19 +266,20 @@ def _attn_bwd_dq(dq, q, K, V, #
                  start_m, start_n, num_steps, #
                  MASK: tl.constexpr):
     offs_m = start_m + tl.arange(0, BLOCK_M2)
-    offs_n = start_n + tl.arange(0, BLOCK_N2)
-    offs_k = tl.arange(0, HEAD_DIM)
-    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
-    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
+    kT_desc = tl.make_tensor_descriptor(K, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_N2])
+
+    vT_desc = tl.make_tensor_descriptor(V, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_N2])
     # D (= delta) is pre-divided by ds_scale.
     Di = tl.load(D + offs_m)
     # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
     curr_n = start_n
     step_n = BLOCK_N2
     for blk_idx in range(num_steps):
-        kT = tl.load(kT_ptrs)
-        vT = tl.load(vT_ptrs)
+        kT = kT_desc.load([0, start_n + blk_idx * step_n])
+        vT = vT_desc.load([0, start_n + blk_idx * step_n])
         qk = tl.dot(q, kT)
         p = tl.math.exp2(qk - m)
         # Autoregressive masking.
@@ -296,8 +296,6 @@ def _attn_bwd_dq(dq, q, K, V, #
         dq += tl.dot(ds, tl.trans(kT))
         # Increment pointers.
         curr_n += step_n
-        kT_ptrs += step_n * stride_tok
-        vT_ptrs += step_n * stride_tok
     return dq


@@ -508,7 +506,7 @@ def backward(ctx, do):
         dv = torch.empty_like(v)
         BATCH, N_HEAD, N_CTX = q.shape[:3]
         PRE_BLOCK = 128
-        NUM_WARPS, NUM_STAGES = 4, 5
+        NUM_WARPS, NUM_STAGES = 16, 3
         BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
         BLK_SLICE_FACTOR = 2
         RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
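Side note (not part of the commit): the flash_attention_benchmark.py change replaces hand-built pointer grids (base + offs[:, None] * stride_tok + offs[None, :] * stride_d, advanced manually each loop iteration) with block descriptors built once via tl.make_tensor_descriptor and indexed per iteration with desc.load. The sketch below illustrates that pattern in isolation; the kernel name, shapes, and parameter names are hypothetical, it assumes a recent Triton with the tensor-descriptor API and a contiguous innermost dimension, and some backends additionally require a scratch allocator (e.g. triton.set_allocator) for device-side descriptors.

# Illustrative sketch only -- not from the commit. Names (copy_rows, BLOCK_M,
# BLOCK_K) are made up; assumes stride_k == 1 (contiguous last dimension).
import triton
import triton.language as tl


@triton.jit
def copy_rows(Src, Dst, M, K, stride_m, stride_k,
              BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    pid = tl.program_id(0)

    # Old pattern (what the diff removes): explicit pointer grid, advanced by hand.
    #   offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    #   offs_k = tl.arange(0, BLOCK_K)
    #   ptrs = Src + offs_m[:, None] * stride_m + offs_k[None, :] * stride_k
    #   tile = tl.load(ptrs)

    # New pattern (what the diff adds): describe the whole tensor once, then
    # address blocks by their element offsets.
    src_desc = tl.make_tensor_descriptor(Src, shape=[M, K], strides=[stride_m, stride_k],
                                         block_shape=[BLOCK_M, BLOCK_K])
    dst_desc = tl.make_tensor_descriptor(Dst, shape=[M, K], strides=[stride_m, stride_k],
                                         block_shape=[BLOCK_M, BLOCK_K])
    tile = src_desc.load([pid * BLOCK_M, 0])
    dst_desc.store([pid * BLOCK_M, 0], tile)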

python/test/gluon/test_consan.py

Lines changed: 9 additions & 9 deletions
@@ -85,7 +85,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr
     tma.store_wait(0)


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_async_tma_kernel(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -141,7 +141,7 @@ def tma_interleave_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: ttgl.cons
     tma.store_wait(0)


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_tma_interleave_kernel(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -190,7 +190,7 @@ def async_copy_kernel(input, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr):
     ampere.async_copy.wait_group(0)


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires ampere or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires ampere or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_async_copy(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -252,7 +252,7 @@ def tcgen5_mma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexp
     mbarrier.invalidate(bar.index(1))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 @pytest.mark.parametrize("MEM_ACCESS_KIND", ["tma_cp", "local_store", "tmem_load", "tmem_store"])
 def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, device, run_wrapper):
@@ -305,7 +305,7 @@ def warpgroup_mma_kernel(input, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr)
     smemA.store(ttgl.full([XBLOCK, XBLOCK], 42, ttgl.float16, blocked_layout))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_warpgroup_mma(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -353,7 +353,7 @@ def warpgroup_mma_kernel2(input, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr
     smemA.store(ttgl.full([XBLOCK, XBLOCK], 42, ttgl.float16, blocked_layout))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_warpgroup_mma2(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -406,7 +406,7 @@ def tcgen5_mma_multibar_kernel(input_desc, XBLOCK: ttgl.constexpr, BUF_IDX: ttgl
     mbarrier.invalidate(bar.index(i))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
 @pytest.mark.parametrize("BUF_IDX", [0, 1])
 @pytest.mark.parametrize("BAR_IDX", [0, 1, 2, 3])
 def test_tcgen5_mma_multibar(BUF_IDX, BAR_IDX, device, run_wrapper):
@@ -529,7 +529,7 @@ def multibuffered_loop_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: t
     mbarrier.invalidate(barMMA.index(i))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_multibuffered_loop(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -611,7 +611,7 @@ def multibuffered_loop_wgmma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE:
     mbarrier.invalidate(barLoadB.index(i))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_multibuffered_wgmma_loop(FAILURE, device, run_wrapper):
     if run_wrapper: