
Commit 12510aa

Merge branch 'main' into tkuczynski/enable_test_small_batch_matmul
2 parents: 7e59bc5 + 9290e9a

137 files changed: +4362 -1953 lines


.github/WINDOWS.md

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ If you do not have a system Python installed at this step, you can install one w
 For example:
 
 ```
-choco install python --version=3.9.13
+choco install python --version=3.10.11
 ```
 
 ### Git

.github/workflows/build-test-python.yml

Lines changed: 2 additions & 2 deletions
@@ -58,9 +58,9 @@ jobs:
         id: matrix
         run: |
           if [[ -n "${{ inputs.runner_label }}" ]]; then
-            matrix='{"python": ["3.9", "3.10", "3.11", "3.12", "3.13"]}'
+            matrix='{"python": ["3.10", "3.11", "3.12", "3.13"]}'
           else
-            matrix='{"python": ["3.9", "3.10", "3.11", "3.12", "3.13"], "driver": ["rolling", "lts"]}'
+            matrix='{"python": ["3.10", "3.11", "3.12", "3.13"], "driver": ["rolling", "lts"]}'
           fi
           echo "matrix=$matrix" | tee -a $GITHUB_OUTPUT
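An aside (not part of the commit): a minimal sketch of how a matrix string like the one above fans out into jobs once a downstream step reads it back with `fromJSON` (that consumption step is assumed, not shown in this hunk):

```python
import json
from itertools import product

# Hypothetical illustration: GitHub Actions takes the cross product of matrix axes.
matrix = json.loads('{"python": ["3.10", "3.11", "3.12", "3.13"], "driver": ["rolling", "lts"]}')
jobs = list(product(*matrix.values()))
print(len(jobs))  # 4 Python versions x 2 drivers = 8 jobs
```

Dropping "3.9" from the axis therefore removes two jobs per workflow run in the two-axis case.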

.github/workflows/nightly-wheels.yml

Lines changed: 0 additions & 1 deletion
@@ -32,7 +32,6 @@ jobs:
     strategy:
       matrix:
         python:
-          - "3.9"
          - "3.10"
          - "3.11"
          - "3.12"

.github/workflows/try-latest-pytorch.yml

Lines changed: 2 additions & 2 deletions
@@ -68,7 +68,7 @@ jobs:
       - name: Matrix
         id: matrix
         run: |
-          integration_matrix='{"python": ["3.9", "3.10", "3.11", "3.12"], "driver": ["rolling", "lts"]}'
+          integration_matrix='{"python": ["3.10", "3.11", "3.12"], "driver": ["rolling", "lts"]}'
 
           echo "integration_matrix=$integration_matrix" | tee -a $GITHUB_OUTPUT
           e2e_matrix='{
@@ -97,7 +97,7 @@
         inductor/test_max_autotune.py
         inductor/test_compile_subprocess.py
       runner_label: ${{ inputs.runner_label }}
-      python_version: "3.9"
+      python_version: "3.10"
 
   integration-tests:
     name: Integration tests

.github/workflows/wheels-pytorch.yml

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ jobs:
     strategy:
       matrix:
         python:
-          - "3.9"
          - "3.10"
          - "3.11"
          - "3.12"

.github/workflows/wheels-triton.yml

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ jobs:
     strategy:
       matrix:
         python:
-          - "3.9"
          - "3.10"
          - "3.11"
          - "3.12"

README.md

Lines changed: 16 additions & 5 deletions
@@ -1,15 +1,26 @@
-<div align="center">
-  <img src="https://lh5.googleusercontent.com/wzQKEsTFkrgNQO9JjhGH5wFvslJr1saLtLaJ_a6Fp_gNENpvt3VG7BmztwngU9hFJaU4CPwGiw1opQtDvTkLrxWRbO_a12Q-pdESWHgtmheIHcPbOL5ZMC4TSiJVe5ty1w=w3517" alt="Triton logo">
-</div>
 
 | **`Documentation`** | **`Nightly Wheels`** |
 |-------------------- | -------------------- |
 | [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |
 
-# Conference Registration
+# Triton Conference 2025
+
+![Triton Registration Banner](https://github.com/user-attachments/assets/b4b6972a-857c-417f-bf2c-f16f38a358c0)
+
+### Registration
 
 The 3rd Triton conference is scheduled to take place on October 21, 2025. Click [here](https://tritonconference.eventbuilder.com/TritonDeveloperConference) to register!
 
+### Poster Submission
+
+We invite members of the Triton community who are attending the Triton Developer Conference to present posters about their Triton-related technical work.
+
+Please submit basic information of your poster, including author information and abstract using this [form](https://forms.gle/QfgTF8o1CWNENAnA7).
+
+**Important Dates**
+- Submission: 10/1/2025
+- Author notification: 10/7/2025
+- Final version (PDF): 10/14/2025
 
 # Triton
 
@@ -27,7 +38,7 @@ You can install the latest stable release of Triton from pip:
 pip install triton
 ```
 
-Binary wheels are available for CPython 3.9-3.13.
+Binary wheels are available for CPython 3.10-3.13.
 
 # Install from source
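An aside (not from the diff): a quick sanity check that a local interpreter falls inside the new wheel window, simply restating the README's claim; the error message wording is mine.

```python
import sys

# Sketch: prebuilt wheels now cover CPython 3.10-3.13 per the README above.
if not ((3, 10) <= sys.version_info[:2] <= (3, 13)):
    raise SystemExit(f"No prebuilt Triton wheel for Python {sys.version.split()[0]}; build from source instead.")
```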

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 55 additions & 52 deletions
@@ -154,6 +154,19 @@ def _attn_fwd_with_block_pointers(Q, K, V, sm_scale, M, Out, #
     # epilogue
     m_i += tl.math.log2(l_i)
     acc = acc / l_i[:, None]
+    if N_CTX <= 512:
+        off_hz = off_z + off_h * H
+    else:
+        off_hz = off_z * H + off_h
+    M_block_ptr = tl.make_block_ptr(
+        base=M + off_hz * N_CTX,
+        shape=[N_CTX],
+        strides=[1],
+        offsets=[start_m * BLOCK_M],
+        block_shape=[BLOCK_M],
+        order=[0],
+    )
+    tl.store(M_block_ptr, m_i)
     tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1))
 
 
@@ -220,7 +233,7 @@ def _attn_bwd_dkdv(dk, dv, #
         if MASK:
             mask = (offs_m[None, :] >= offs_n[:, None])
             pT = tl.where(mask, pT, 0.0)
-        do = tl.load(do_ptrs).to(tl.float16)
+        do = tl.load(do_ptrs)
         # Compute dV.
         ppT = pT
         ppT = ppT.to(tl.float16)
@@ -275,7 +288,7 @@ def _attn_bwd_dq(dq, q, K, V, #
         mask = (offs_m[:, None] >= offs_n[None, :])
         p = tl.where(mask, p, 0.0)
         # Compute dP and dS.
-        dp = tl.dot(do.to(tl.float16), vT).to(tl.float32)
+        dp = tl.dot(do, vT).to(tl.float32)
         ds = p * (dp - Di[:, None])
         ds = ds.to(tl.float16)
         # Compute dQ.
@@ -423,12 +436,12 @@ class _attention(torch.autograd.Function):
     attn_fwd: Callable = None
 
     @staticmethod
-    def forward(ctx, q, k, v, causal, sm_scale, dq, dk, dv, delta):
+    def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
         assert Lk in {16, 32, 64, 128}
-        o = torch.empty_like(q, dtype=torch.float32)
+        o = torch.empty_like(q)
         BLOCK_M = 128
         BLOCK_N = 64
         num_stages = 3
@@ -473,8 +486,7 @@ def forward(ctx, q, k, v, causal, sm_scale, dq, dk, dv, delta):
             advanced_path=True,  #
         )
 
-        ctx.save_for_backward(q, k, v, o, M, dq, dk, dv, delta)
-        ctx.grid = grid
+        ctx.save_for_backward(q, k, v, o, M)
         ctx.sm_scale = sm_scale
         ctx.HEAD_DIM = Lk
         ctx.causal = causal
@@ -488,9 +500,12 @@ def backward(ctx, do):
         with record_function(
                 '__profile_kernel_of_func_bwd_fa'
         ) if benchmark_suite.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER' else contextlib.nullcontext():
-            q, k, v, o, M, dq, dk, dv, delta = ctx.saved_tensors
+            q, k, v, o, M = ctx.saved_tensors
             assert do.is_contiguous()
             assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
             BATCH, N_HEAD, N_CTX = q.shape[:3]
             PRE_BLOCK = 128
             NUM_WARPS, NUM_STAGES = 4, 5
@@ -502,6 +517,7 @@ def backward(ctx, do):
             PRE_BLOCK = 128
             assert N_CTX % PRE_BLOCK == 0
             pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
+            delta = torch.empty_like(M)
             _attn_bwd_preprocess[pre_grid](
                 o, do,  #
                 delta,  #
@@ -522,7 +538,7 @@ def backward(ctx, do):
                 num_stages=NUM_STAGES  #
             )
 
-        return dq, dk, dv, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None
 
 
 attention = _attention.apply
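For orientation (my sketch, not part of the commit): after this change `forward` no longer takes preallocated `dq`/`dk`/`dv`/`delta` buffers, and `backward` materializes them itself. A minimal usage sketch, with shapes and the `xpu` device borrowed from the benchmark below:

```python
import torch

# Assumes an Intel XPU build of PyTorch/Triton; `attention` is _attention.apply above.
q = torch.randn((4, 48, 1024, 64), device='xpu', dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

o = attention(q, k, v, True, 0.125)  # causal=True, sm_scale=0.125 -- only five inputs now
o.backward(torch.randn_like(o))      # dq/dk/dv are allocated inside backward()
```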
@@ -537,6 +553,9 @@ def get_benchmark(
     Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
     The benchmark can then be executed by calling the :code:`.run` method on the return value.
     """
+    causal_mode = [False, True] if fa_kernel_mode == 'fwd' else [
+        True
+    ]  # The 06 tutorial bwd Non-causal tests do not pass at the moment.
 
     supported_providers = {
         'triton': 'Triton',
@@ -556,9 +575,9 @@
             x_vals=[[z, h, 16384 // z, dhead, causal, mode]
                     for z in [1, 2, 4, 8, 16, 32]
                     for (h, dhead) in [(16, 128), (32, 64)]
-                    for causal in [False, True]
+                    for causal in causal_mode
                     for mode in [fa_kernel_mode]]  #
-            + [[4, 48, 1024, 64, causal, mode] for causal in [False, True] for mode in [fa_kernel_mode]],
+            + [[4, 48, 1024, 64, causal, mode] for causal in causal_mode for mode in [fa_kernel_mode]],
             line_arg='provider',
             # argument name whose value corresponds to a different line in the plot
             # possible values for `line_arg``
@@ -587,60 +606,44 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
         if MODE not in modes:
             raise AssertionError(f'Unknown {MODE}, supported modes are {modes}')
         dtype = torch.float16
+        torch.xpu.empty_cache()
         q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
         k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
         v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
         sm_scale = 0.125
-        dq, dk, dv, delta = None, None, None, None
-        if MODE == 'bwd':
-            sm_scale = 1.3
-            dq = torch.empty_like(q)
-            dk = torch.empty_like(k)
-            dv = torch.empty_like(v)
-            delta = torch.empty_like(q)
         quantiles = [0.5, 0.0, 1.0]
         atol = 1e-1 if N_CTX == 16384 else 1e-2
+        bwd_atol = 1e-1 if N_CTX >= 4096 else 1e-2
         # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
         torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
-        ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-        if MODE == 'bwd':
-            torch_o = torch_fn()
-            torch_do = torch.randn_like(torch_o)
-            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-
-        if provider == 'onednn':
-            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                torch_fn,
-                n_warmup=n_warmup,
-                n_repeat=10,
-                quantiles=quantiles,
-                time_warmup=False,
-            )
+        ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale)
 
-        elif provider == 'triton':
-            triton_fn = lambda: attention(q, k, v, CAUSAL, sm_scale, dq, dk, dv, delta)
-            if MODE == 'bwd':
-                triton_o = triton_fn()
-                triton_do = torch.randn_like(triton_o)
-                triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
+        if provider == 'triton':
+            triton_fn = lambda: attention(q, k, v, CAUSAL, sm_scale)
             if MODE == 'fwd':
                 benchmark_suite.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
             else:
-                benchmark_suite.assert_close(
-                    lambda: triton_o,
-                    lambda: torch_o,
-                    atol=1e-2,
-                    rtol=0,
-                    err_msg='triton to torch',
-                )
-
-            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                triton_fn,
-                n_warmup=n_warmup,
-                n_repeat=10,
-                quantiles=quantiles,
-                time_warmup=False,
-            )
+                dout = torch.randn_like(q)
+                torch_o = torch_fn()
+                torch_grads = torch.autograd.grad((torch_o, ), (q, k, v), dout.cpu(), retain_graph=True)
+                eager_tensors = torch_grads
+                triton_o = triton_fn()
+                triton_grads = torch.autograd.grad((triton_o, ), (q, k, v), dout, retain_graph=True)
+                compiled_tensors = triton_grads
+
+                benchmark_suite.assert_close(lambda: torch_o, lambda: triton_o, atol=atol, rtol=1e-3,
+                                             err_msg='Error comparing out between triton and torch')
+
+                tensor_names = ['grad_query', 'grad_key', 'grad_value']
+                for eager, compiled, name in zip(eager_tensors, compiled_tensors, tensor_names):
+                    benchmark_suite.assert_close(lambda eager=eager: eager, lambda compiled=compiled: compiled,
+                                                 atol=bwd_atol, rtol=1e-3,
                                                 err_msg=f'Error comparing {name} between triton and torch')
+                triton_fn = lambda: triton_o.backward(dout, retain_graph=True)
+
+            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(triton_fn, n_warmup=n_warmup, n_repeat=10,
+                                                                   quantiles=quantiles, grad_to_none=(q, k, v),
                                                                   time_warmup=False)
 
         elif provider == 'xetla':
             if MODE == 'bwd':
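One non-obvious detail in the new comparison loop above (a general Python note, not specific to this commit): the `lambda eager=eager: eager` default-argument form is what keeps each closure bound to its own loop iteration, since Python closures capture variables late.

```python
# Without the default argument, every closure sees the final loop value:
late = [lambda: x for x in range(3)]
bound = [lambda x=x: x for x in range(3)]
print([f() for f in late])   # [2, 2, 2]
print([f() for f in bound])  # [0, 1, 2]
```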

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 9 additions & 9 deletions
@@ -41,15 +41,15 @@ def naive_softmax(x):
 
 @triton.autotune(
     configs=[
-        triton.Config({"threads_per_warp": 32}, num_warps=32),
-        triton.Config({"threads_per_warp": 32}, num_warps=16),
-        triton.Config({"threads_per_warp": 32}, num_warps=8),
-        triton.Config({"threads_per_warp": 32}, num_warps=4),
-        triton.Config({"threads_per_warp": 16}, num_warps=64),
-        triton.Config({"threads_per_warp": 16}, num_warps=32),
-        triton.Config({"threads_per_warp": 16}, num_warps=16),
-        triton.Config({"threads_per_warp": 16}, num_warps=8),
-        triton.Config({"threads_per_warp": 16}, num_warps=4),
+        triton.Config({"warp_size": 32}, num_warps=32),
+        triton.Config({"warp_size": 32}, num_warps=16),
+        triton.Config({"warp_size": 32}, num_warps=8),
+        triton.Config({"warp_size": 32}, num_warps=4),
+        triton.Config({"warp_size": 16}, num_warps=64),
+        triton.Config({"warp_size": 16}, num_warps=32),
+        triton.Config({"warp_size": 16}, num_warps=16),
+        triton.Config({"warp_size": 16}, num_warps=8),
+        triton.Config({"warp_size": 16}, num_warps=4),
     ],
     key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
 )
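As background (my sketch, not from the commit): `triton.autotune` times each `Config` on the first launch and caches the winner per distinct value of the `key` arguments. The `warp_size` kwarg here appears to be the Intel XPU backend's launch option formerly spelled `threads_per_warp`; the kernel below is hypothetical and only illustrates the shape of such a setup.

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"warp_size": 32}, num_warps=4),  # warp_size: assumed XPU launch option
        triton.Config({"warp_size": 16}, num_warps=8),
    ],
    key=["n_elements"],  # re-tune whenever n_elements changes
)
@triton.jit
def add_one(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, x + 1, mask=mask)
```

Launching with `add_one[lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)](x, n, BLOCK_SIZE=1024)` would benchmark both configs once, then reuse the faster one for later calls with the same `n_elements`.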

bin/RegisterTritonDialects.h

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,7 @@
 
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
 #include "mlir/InitAllPasses.h"
 
 namespace mlir {
@@ -107,6 +108,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::registerConvertTritonGENToLLVM();
   mlir::triton::registerTritonGENToLLVMPasses();
   mlir::triton::registerTritonGENToSPIRVPasses();
+  mlir::LLVM::registerInlinerInterface(registry);
+  mlir::NVVM::registerInlinerInterface(registry);
 
   // TritonAMDGPUToLLVM passes
   mlir::triton::registerAllocateAMDGPUSharedMemory();
