
Commit caac355

Merge commit '2a04155bd063630a2b59b0d437d922b12828fbbd'
2 parents 71a23b2 + 2a04155 commit caac355

File tree: 12 files changed (+274, -71 lines)

RELEASE.md

Lines changed: 50 additions & 1 deletion
@@ -1,4 +1,53 @@
-# Release Process
+# Releasing Triton
+
+Triton releases provide a stable snapshot of the code base, packaged as a binary that can easily be consumed through PyPI. Releases are also the points at which we, the development team, signal to the community which new features are available, what improvements have been made, and which upcoming changes may affect them (i.e. breaking changes).
+
+## Release Compatibility Matrix
+
+Following is the release compatibility matrix for Triton releases:
+
+| Triton version | Python version | Manylinux version |
+| --- | --- | --- |
+| 3.2.0 | >=3.9, <=3.13 | glibc 2.17+ x86-64 |
+| 3.1.0 | >=3.8, <=3.12 | glibc 2.17+ x86-64 |
+| 3.0.0 | >=3.8, <=3.12 | glibc 2.17+ x86-64 |
+| 2.3.1 | >=3.7, <=3.12 | glibc 2.17+ x86-64 |
+| 2.3.0 | >=3.7, <=3.12 | glibc 2.17+ x86-64 |
+| 2.2.0 | >=3.7, <=3.12 | glibc 2.17+ x86-64 |
+| 2.1.0 | >=3.7, <=3.11 | glibc 2.17+ x86-64 |
+| 2.0.0 | >=3.6, <=3.11 | glibc 2.17+ x86-64 |
+| 1.1.1 | >=3.6, <=3.9 | glibc 2.17+ x86-64 |
+| 1.1.0 | >=3.6, <=3.9 | glibc 2.17+ x86-64 |
+| 1.0.0 | >=3.6, <=3.9 | glibc 2.17+ x86-64 |
+
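As a quick illustration of how to read this matrix (editorial, not part of the diff), the sketch below checks the running interpreter against the published range for a given release; the `SUPPORTED` dict transcribes two rows of the table above.

```python
import sys

# Transcribed from the compatibility matrix above (two rows shown for brevity).
SUPPORTED = {
    "3.2.0": ((3, 9), (3, 13)),
    "3.1.0": ((3, 8), (3, 12)),
}

def python_ok(triton_version: str) -> bool:
    low, high = SUPPORTED[triton_version]
    return low <= sys.version_info[:2] <= high

print(python_ok("3.2.0"))  # True on any Python from 3.9 through 3.13
```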
+## Release Cadence
+
+Following is the release cadence for 2024/2025. All future release dates below are tentative. Please note: patch releases are optional.
+
+| Minor version | Release branch cut | Release date | Patch release date |
+| --- | --- | --- | --- |
+| 3.5.0 | Sep 2025 | Oct 2025 | --- |
+| 3.4.0 | Jun 2025 | Jul 2025 | --- |
+| 3.3.0 | Feb/Mar 2025 | Apr 2025 | --- |
+| 3.2.0 | Dec 2024 | Jan 2025 | --- |
+| 3.1.0 | Jun 2024 | Oct 2024 | --- |
+| 3.0.0 | Jun 2024 | Jul 2024 | --- |
+| 2.3.0 | Dec 2023 | Apr 2024 | May 2024 |
+| 2.2.0 | Dec 2023 | Jan 2024 | --- |
+
+## Release Cherry-Pick Criteria
+
+After the branch cut, we finalize the release branch using clear criteria for which cherry-picks are allowed in. (A cherry-pick is the process of landing a PR on the release branch after the branch cut.) Cherry-picks are deliberately limited so that the team has sufficient time to complete a thorough round of testing on a stable code base. The allowed categories are:
+
+* Regression fixes that address a functional or performance regression against the most recent release (e.g. 3.2 for the 3.3 release)
+* Critical fixes for severe issues such as silent incorrectness, broken backwards compatibility, crashes, deadlocks, and (large) memory leaks
+* Fixes to new features introduced in the most recent release (e.g. 3.2 for the 3.3 release)
+* Documentation improvements
+* Release-branch-specific changes (e.g. changing version identifiers, CI fixes)
+
+Please note: **no feature work is allowed in cherry-picks**. All PRs considered for cherry-picking must first be merged on trunk; the only exception is release-branch-specific changes. An issue for tracking cherry-picks to the release branch is created after the branch cut. **Only issues marked 'cherry-picks' in the issue tracker will be considered for the release.**
+
+# Intel Release Process
 
 Intel XPU Backend for Triton releases are aligned to the upstream `triton-lang/triton` project and to `PyTorch`. To make a release:

lib/Analysis/Utility.cpp

Lines changed: 1 addition & 1 deletion
@@ -302,7 +302,7 @@ bool ScanLoweringHelper::isSupported() {
 }
 
 unsigned ScanLoweringHelper::getScratchSizeInElems() {
-  unsigned numWarps = lookupNumWarps(scanOp);
+  unsigned numWarps = product(getEncoding().getWarpsPerCTA());
   unsigned numNonAxisElementsPerWarp =
       getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
   unsigned numElements = numWarps * numNonAxisElementsPerWarp *
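The replacement derives the warp count from the scan result's own layout rather than looking it up from the surrounding op. A minimal Python rendering of the quantity now computed, assuming a layout carrying a `warps_per_cta` tuple (hypothetical name):

```python
from math import prod

warps_per_cta = (2, 2)           # hypothetical layout attribute: a 2x2 warp grid
num_warps = prod(warps_per_cta)  # what product(getEncoding().getWarpsPerCTA()) computes
print(num_warps)                 # -> 4
```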

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 1 addition & 12 deletions
@@ -372,18 +372,7 @@ static std::optional<ttg::SharedEncodingTrait>
 getSharedEncoding(Operation *loadOp, bool isTMALoad) {
   auto ty = cast<RankedTensorType>(loadOp->getResultTypes()[0]);
   auto ctaLayout = ttg::getCTALayout(ty.getEncoding());
-  auto blockedOrder = ttg::getOrder(ty.getEncoding());
-  SmallVector<unsigned> order;
-  if (blockedOrder.size() == 3) {
-    for (unsigned i = 0; i < blockedOrder.size(); ++i) {
-      if (blockedOrder[i] == 0)
-        continue;
-      order.push_back(blockedOrder[i]);
-    }
-    order.push_back(0);
-  } else {
-    order = blockedOrder;
-  }
+  auto order = ttg::getOrder(ty.getEncoding());
 
   ttg::SharedEncodingTrait localAllocEnc;
   if (llvm::any_of(loadOp->getUsers(), [&](Operation *user) {
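For reference, here is the deleted special case in Python form (illustrative only): for rank-3 layouts the old code rotated dimension 0 to the end of the order, while the new code simply takes the layout's order unchanged.

```python
def old_order(blocked_order):
    # Removed behavior: for rank-3 orders, move dim 0 to the back.
    if len(blocked_order) == 3:
        return [d for d in blocked_order if d != 0] + [0]
    return list(blocked_order)

print(old_order([2, 0, 1]))  # -> [2, 1, 0]
print(old_order([1, 0]))     # -> [1, 0] (non-rank-3 orders pass through)
```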

python/triton/runtime/jit.py

Lines changed: 6 additions & 2 deletions
@@ -587,6 +587,9 @@ def run(self, *args, grid, warmup, **kwargs):
                        *bound_args.values())
         return kernel
 
+    def repr(self, _):
+        return self._fn_name if self._repr is None else self._repr(_)
+
     def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
                  noinline=None, repr=None, launch_metadata=None):
         do_not_specialize = do_not_specialize if do_not_specialize else []
@@ -599,7 +602,8 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
         self.do_not_specialize = do_not_specialize
         self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
         self.starting_line_number = inspect.getsourcelines(fn)[1]
-        self.repr = lambda _: fn.__name__ if repr is None else repr(_)
+        self._repr = repr
+        self._fn_name = fn.__name__
         self.launch_metadata = launch_metadata
 
         self.params = []
@@ -613,7 +617,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
             src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
         self._unsafe_update_src(src)
         # cache of just-in-time compiled kernels
-        self.device_caches = defaultdict(lambda: self.create_binder())
+        self.device_caches = defaultdict(self.create_binder)
         self.hash = None
 
         # Map of global variables used by the function and any functions it
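Both hunks replace closures over local variables with plain attributes and a bound method. A minimal standalone sketch of the resulting pattern (hypothetical class, not Triton's actual code):

```python
from collections import defaultdict

class Kernel:
    def __init__(self, fn, repr=None):
        # Keep the pieces as attributes instead of capturing fn/repr in a lambda.
        self._repr = repr
        self._fn_name = fn.__name__
        # A bound method is already a zero-argument callable, so it can be
        # passed to defaultdict directly -- no wrapping lambda needed.
        self.device_caches = defaultdict(self.create_binder)

    def repr(self, spec):
        return self._fn_name if self._repr is None else self._repr(spec)

    def create_binder(self):
        return {}

k = Kernel(fn=print)
print(k.repr(None))            # -> "print" (falls back to the function name)
_ = k.device_caches["cuda:0"]  # factory runs on first access
```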

test/Analysis/test-allocation.mlir

Lines changed: 10 additions & 0 deletions
@@ -615,4 +615,14 @@ tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
   // CHECK-NEXT: size = 1024
 }
 
+// CHECK-LABEL: scan_alloc
+tt.func @scan_alloc(%x : tensor<8x16xf32, #AL>) {
+  // CHECK: offset = 0, size = 512
+  %a = "tt.scan"(%x) <{axis = 0 : i32, reverse = false}>({
+  ^bb0(%arg0: f32, %arg1: f32):
+    %add = arith.addf %arg0, %arg1 : f32
+    tt.scan.return %add : f32
+  }) : (tensor<8x16xf32, #AL>) -> tensor<8x16xf32, #AL>
+  tt.return
+}
 }
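A quick sanity check of the `size = 512` expectation, assuming the scan scratch here holds one f32 per element of the 8x16 input:

```python
rows, cols, bytes_per_f32 = 8, 16, 4
print(rows * cols * bytes_per_f32)  # -> 512, matching the CHECK line above
```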

test/Conversion/amd/buffer_load_store.mlir

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
     // CHECK: %[[offset:.*]] = llvm.select %[[c_mask]]
     // CHECK: %[[aux:.*]] = llvm.mlir.constant(3 : i32) : i32
     // CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]], {{.*}}, %[[aux]]
-    %c0 = arith.constant 0 : i32
-    %ret = amdgpu.buffer_load %arg0[%offset] cacheModifier = cs stride = %c0 : tensor<128xf32, #blocked0>
+    %ret = amdgpu.buffer_load %arg0[%offset] cacheModifier = cs : tensor<128xf32, #blocked0>
     tt.return
   }
 }

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 1 addition & 1 deletion
@@ -566,7 +566,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
     %6 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
     %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
-    // CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] stride = %c0_i32
+    // CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
     %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>
     tt.return %8 : tensor<1024xf32, #blocked>
 }

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 77 additions & 28 deletions
@@ -41,8 +41,7 @@ include "TritonAMDGPUAttrDefs.td"
 
 
 class TT_AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
-    Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])> {
-}
+    Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])>;
 
 //
 // Interfaces
@@ -53,8 +52,7 @@ def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
 // ExtractSliceOp
 //===----------------------------------------------------------------------===//
 
-def ExtractSliceOp
-    : TT_AMDGPU_Op<"extract_slice", [Pure]> {
+def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> {
   let summary = "extract slice operation";
   let description = [{
     The "extract_slice" operation enables extracting a slice of a tensor in
@@ -92,8 +90,10 @@ def ExtractSliceOp
     size of the slice is determined by the result type.
   }];
 
-  let arguments = (ins AnyRankedTensor:$source,
-                       DenseI64ArrayAttr:$static_offsets);
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    DenseI64ArrayAttr:$static_offsets
+  );
   let results = (outs AnyRankedTensor:$result);
 
   let builders = [
@@ -117,6 +117,10 @@ def ExtractSliceOp
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// InstructionSchedHint
+//===----------------------------------------------------------------------===//
+
 def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
   let summary = "A placeholder op for instruction scheduling hints within a basic block";
   let description = [{
@@ -156,8 +160,11 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
   let assemblyFormat = [{ attr-dict }];
 }
 
-def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
-    Arguments<(ins I1:$pred)> {
+//===----------------------------------------------------------------------===//
+// CondBarrierOp
+//===----------------------------------------------------------------------===//
+
+def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier"> {
   let summary = "Conditionally set barriers to synchronize partial threads in a block";
 
   let description = [{
@@ -170,22 +177,25 @@ def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
     NB. This doesn't set any memory fence.
   }];
 
+  let arguments = (ins I1:$pred);
+
   let assemblyFormat = "$pred attr-dict";
 }
 
-//
-// AMD Buffer operations.
-//
+//===----------------------------------------------------------------------===//
+// BufferLoadOp
+//===----------------------------------------------------------------------===//
+
 def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
     SameLoadStoreOperandsAndResultEncoding,
     AttrSizedOperandSegments,
    MemoryEffects<[MemRead<GlobalMemory>]>,
    TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
    TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
    TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
-                   "($_op.getOperands().size() <= 3) || std::equal_to<>()">,
+                   "(cast<BufferLoadOp>($_op).getMask() == nullptr) || std::equal_to<>()">,
    TypesMatchWith<"result and other have the same type", "result", "other", "$_self",
-                   "($_op.getOperands().size() <= 4) || std::equal_to<>()">,
+                   "(cast<BufferLoadOp>($_op).getOther() == nullptr) || std::equal_to<>()">,
 ]>{
   let summary = "Load from a scalar base pointer and a tensor offset";
   let description = [{
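Why the verifier predicates changed: once `$stride` becomes optional (see the hunks below), an operand-count test like `getOperands().size() <= 3` no longer tells you whether `$mask` is present. A Python illustration of the ambiguity (hypothetical operand lists, not real API):

```python
# Operand stack for buffer_load: (ptr, offsets, stride?, mask?, other?).
with_mask_no_stride = ["ptr", "offsets", "mask"]    # mask present
with_stride_no_mask = ["ptr", "offsets", "stride"]  # mask absent

# Old-style predicate: "a mask exists iff there are more than 3 operands".
for ops in (with_mask_no_stride, with_stride_no_mask):
    print(len(ops) > 3)  # False both times, even though one has a mask

# Querying the named operand directly (the getMask() == nullptr check) stays
# correct no matter which combination of optional operands is present.
```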
@@ -201,11 +211,10 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
     when it converts to the buffer ops because it is important for optimizing
     the cache memory access.
   }];
-  let arguments = (
-    ins
+  let arguments = (ins
     TT_Ptr:$ptr,
     I32Tensor:$offsets,
-    I32:$stride,
+    Optional<I32>:$stride,
     DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
     Optional<TT_BoolTensor>:$mask,
     Optional<TT_Tensor>:$other
@@ -215,24 +224,29 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
   let assemblyFormat = [{
     $ptr `[` $offsets `]` (`,` $mask^)? (`,` $other^)?
     oilist(`cacheModifier` `=` $cache)
-    `stride` `=` $stride
+    (`stride` `=` $stride^)?
     attr-dict `:` type($result)
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// BufferAtomicRMWOp
+//===----------------------------------------------------------------------===//
+
 def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
+    AttrSizedOperandSegments,
     SameLoadStoreOperandsAndResultEncoding,
     MemoryEffects<[MemRead<GlobalMemory>]>,
     MemoryEffects<[MemWrite<GlobalMemory>]>,
     TypesMatchWith<"result element type matches the value type", "result", "value", "$_self">,
     TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
     TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
     TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
-                   "($_op.getOperands().size() <= 4) || std::equal_to<>()">,
+                   "(cast<BufferAtomicRMWOp>($_op).getMask() == nullptr) || std::equal_to<>()">,
     TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
     TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
     TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
-                   "($_op.getOperands().size() <= 4) || std::equal_to<>()">,
+                   "(cast<BufferAtomicRMWOp>($_op).getMask() == nullptr) || std::equal_to<>()">,
 ]>{
   let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset";
   let description = [{
@@ -246,13 +260,12 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
     the address difference between the first elements of each row in bytes. Compiler tries to obtain the `stride`
     when it converts to the buffer ops because it is important for optimizing the cache memory access.
   }];
-  let arguments = (
-    ins
+  let arguments = (ins
     TT_AtomicRMWAttr:$atomic_rmw_op,
     TT_Ptr:$ptr,
     I32Tensor:$offsets,
     TT_Tensor:$value,
-    I32:$stride,
+    Optional<I32>:$stride,
     TT_MemSemanticAttr:$sem,
     TT_MemSyncScopeAttr:$scope,
     Optional<TT_BoolTensor>:$mask
@@ -261,18 +274,23 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
 
   let assemblyFormat = [{
     $atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
-    `stride` `=` $stride
+    (`stride` `=` $stride^)?
     attr-dict `:` type($result)
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// BufferStoreOp
+//===----------------------------------------------------------------------===//
+
 def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [
+    AttrSizedOperandSegments,
     SameLoadStoreOperandsEncoding,
     MemoryEffects<[MemWrite<GlobalMemory>]>,
     TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
     TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
     TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
-                   "($_op.getOperands().size() <= 4) || std::equal_to<>()">,
+                   "(cast<BufferStoreOp>($_op).getMask() == nullptr) || std::equal_to<>()">,
 ]>{
   let summary = "Store into scalar base pointer and a tensor offset";
   let description = [{
@@ -288,22 +306,53 @@ def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [
     when it converts to the buffer ops because it is important for optimizing
     the cache memory access.
   }];
-  let arguments = (
-    ins
+  let arguments = (ins
     TT_Tensor:$value,
     TT_Ptr:$ptr,
     I32Tensor:$offsets,
-    I32:$stride,
+    Optional<I32>:$stride,
     DefaultValuedAttr<TT_CacheModifierAttr, "mlir::triton::CacheModifier::NONE">:$cache,
     Optional<TT_BoolTensor>:$mask
   );
 
   let assemblyFormat = [{
     $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
     oilist(`cacheModifier` `=` $cache)
-    `stride` `=` $stride
+    (`stride` `=` $stride^)?
     attr-dict `:` type($value)
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// UpcastMXFPOp
+//===----------------------------------------------------------------------===//
+
+def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> {
+  let summary = "Convert an mxfp tensor to bf16/fp16";
+
+  let hasVerifier = 1;
+
+  let description = [{
+    Compute the bf16 encoded in the given mxfp number as per
+    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+  }];
+  let arguments = (
+    ins
+    TT_Tensor:$src,
+    TT_Tensor:$scale,
+    TT_ScaleDotElemTypeAttr:$fp_type,
+    BoolAttr:$fastMath
+  );
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{
+    $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    static RankedTensorType deduceOutputType(
+        TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
+  }];
+}
+
 #endif
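To make the new op's semantics concrete: per the OCP MX spec it references, each block of mxfp elements shares a single E8M0 scale, and the upcast multiplies the decoded elements by that scale. A rough Python sketch, assuming an E8M0 bias of 127 with NaN at the all-ones encoding (this is not the backend's actual decoder):

```python
def upcast_mx_block(elements, scale_e8m0):
    # E8M0 scale: 2**(x - 127) for x in 0..254; 255 encodes NaN.
    scale = float("nan") if scale_e8m0 == 255 else 2.0 ** (scale_e8m0 - 127)
    return [e * scale for e in elements]

# A shared scale byte of 128 means multiply the block by 2**(128 - 127) = 2.
print(upcast_mx_block([0.5, 1.0, 1.5], 128))  # -> [1.0, 2.0, 3.0]
```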
