Skip to content

Commit 904f124

Browse files
authored
Merge OpenAI commit 01d3c87 (#5102)
This PR changes the Triton base from 8b792c8 to 01d3c87 (Sep 9). Pass rate: 98.11%
2 parents 53bb00f + c1d3d66 commit 904f124

File tree

18 files changed

+364
-308
lines changed

18 files changed

+364
-308
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
|-------------------- | -------------------- |
77
| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |
88

9+
# Conference Registration
10+
11+
The 3rd Triton conference is scheduled to take place on October 21, 2025. Click [here](https://tritonconference.eventbuilder.com/TritonDeveloperConference) to register!
12+
13+
914
# Triton
1015

1116
This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.

lib/Target/LLVMIR/LLVMDIScope.cpp

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ FileLineColLoc extractFileLoc(Location loc) {
3030
return extractFileLoc(opaqueLoc.getFallbackLocation());
3131
if (auto fusedLoc = dyn_cast<FusedLoc>(loc))
3232
return extractFileLoc(fusedLoc.getLocations().front());
33-
if (auto callerLoc = dyn_cast<CallSiteLoc>(loc))
34-
return extractFileLoc(callerLoc.getCaller());
33+
// Prefer the innermost callee for callsite locations.
34+
if (auto csLoc = dyn_cast<CallSiteLoc>(loc))
35+
return extractFileLoc(csLoc.getCallee());
3536
StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), "<unknown>");
3637
return mlir::FileLineColLoc::get(unknownFile, 0, 0);
3738
}
@@ -109,39 +110,39 @@ struct LLVMDIScopePass : public impl::LLVMDIScopeBase<LLVMDIScopePass> {
109110
funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr));
110111
}
111112

112-
// Get a nested loc for inlined functions
113-
Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr,
114-
Location calleeLoc) {
115-
auto calleeFileName = extractFileLoc(calleeLoc).getFilename();
116-
auto context = op->getContext();
117-
LLVM::DIFileAttr calleeFileAttr = LLVM::DIFileAttr::get(
118-
context, llvm::sys::path::filename(calleeFileName),
119-
llvm::sys::path::parent_path(calleeFileName));
120-
auto lexicalBlockFileAttr = LLVM::DILexicalBlockFileAttr::get(
121-
context, scopeAttr, calleeFileAttr, /*discriminator=*/0);
122-
Location loc = calleeLoc;
123-
if (mlir::isa<CallSiteLoc>(calleeLoc)) {
124-
auto nestedLoc = mlir::cast<CallSiteLoc>(calleeLoc).getCallee();
125-
loc = getNestedLoc(op, lexicalBlockFileAttr, nestedLoc);
126-
}
127-
return FusedLoc::get(context, {loc}, lexicalBlockFileAttr);
128-
}
129-
130113
void setLexicalBlockFileAttr(Operation *op) {
131-
auto opLoc = op->getLoc();
132-
if (auto callSiteLoc = dyn_cast<CallSiteLoc>(opLoc)) {
133-
auto callerLoc = callSiteLoc.getCaller();
134-
auto calleeLoc = callSiteLoc.getCallee();
135-
LLVM::DIScopeAttr scopeAttr;
136-
// We assemble the full inline stack so the parent of this loc must be a
137-
// function
138-
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
139-
auto funcOpLoc = mlir::cast<FusedLoc>(funcOp.getLoc());
140-
scopeAttr = mlir::cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata());
141-
auto loc =
142-
CallSiteLoc::get(getNestedLoc(op, scopeAttr, calleeLoc), callerLoc);
143-
op->setLoc(loc);
144-
}
114+
Location opLoc = op->getLoc();
115+
if (!isa<CallSiteLoc>(opLoc))
116+
return;
117+
118+
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
119+
auto funcOpLoc = mlir::cast<FusedLoc>(funcOp.getLoc());
120+
auto scopeAttr =
121+
mlir::cast<LLVM::DISubprogramAttr>(funcOpLoc.getMetadata());
122+
123+
MLIRContext *ctx = op->getContext();
124+
std::function<Location(Location)> makeScoped =
125+
[&](Location loc) -> Location {
126+
if (auto cs = dyn_cast<CallSiteLoc>(loc)) {
127+
Location newCallee = makeScoped(cs.getCallee());
128+
Location newCaller = makeScoped(cs.getCaller());
129+
return CallSiteLoc::get(newCallee, newCaller);
130+
}
131+
132+
// Build a DIFile for this leaf location
133+
FileLineColLoc fileLine = extractFileLoc(loc);
134+
StringRef inputFilePath = fileLine.getFilename().getValue();
135+
LLVM::DIFileAttr fileAttr =
136+
LLVM::DIFileAttr::get(ctx, llvm::sys::path::filename(inputFilePath),
137+
llvm::sys::path::parent_path(inputFilePath));
138+
139+
auto lexicalBlock =
140+
LLVM::DILexicalBlockFileAttr::get(ctx, scopeAttr, fileAttr,
141+
/*discriminator=*/0);
142+
return FusedLoc::get(ctx, {loc}, lexicalBlock);
143+
};
144+
145+
op->setLoc(makeScoped(opLoc));
145146
}
146147

147148
void runOnOperation() override {

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 --convert-scf-to-cf --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=COMMON,GFX950
2-
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-scf-to-cf --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=COMMON,GFX942
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 --convert-scf-to-cf | FileCheck %s --check-prefixes=COMMON,GFX950
2+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-scf-to-cf | FileCheck %s --check-prefixes=COMMON,GFX942
33

44
// COMMON: [[$ASYNC_COPY_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.AsyncCopies"
55
// COMMON: [[$LOCAL_LOAD_SCOPE:#.*]] = #llvm.alias_scope<id = "amdgpu.LocalLoads"

test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,22 +176,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
176176
// CHECK: llvm.cond_br
177177
// CHECK: rocdl.global.load.lds
178178
// CHECK-NEXT: llvm.br
179-
// CHECK: _predicated_store
179+
// CHECK: llvm.cond_br
180+
// CHECK: llvm.store
180181

181182
// CHECK: llvm.cond_br
182183
// CHECK: rocdl.global.load.lds
183184
// CHECK-NEXT: llvm.br
184-
// CHECK: _predicated_store
185+
// CHECK: llvm.cond_br
186+
// CHECK: llvm.store
185187

186188
// CHECK: llvm.cond_br
187189
// CHECK: rocdl.global.load.lds
188190
// CHECK-NEXT: llvm.br
189-
// CHECK: _predicated_store
191+
// CHECK: llvm.cond_br
192+
// CHECK: llvm.store
190193

191194
// CHECK: llvm.cond_br
192195
// CHECK: rocdl.global.load.lds
193196
// CHECK-NEXT: llvm.br
194-
// CHECK: _predicated_store
197+
// CHECK: llvm.cond_br
198+
// CHECK: llvm.store
195199

196200
%2 = ttg.async_copy_global_to_local %1, %arg2 mask %67 other %cst_0 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
197201
tt.return
@@ -236,28 +240,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
236240
// CHECK: llvm.cond_br
237241
// CHECK: rocdl.global.load.lds
238242
// CHECK-NEXT: llvm.br
239-
// CHECK: _predicated_store
243+
// CHECK: llvm.cond_br
244+
// CHECK: llvm.store
240245

241246
// CHECK: rocdl.ds_bpermute
242247
// CHECK: rocdl.ballot
243248
// CHECK: llvm.cond_br
244249
// CHECK: rocdl.global.load.lds
245250
// CHECK-NEXT: llvm.br
246-
// CHECK: _predicated_store
251+
// CHECK: llvm.cond_br
252+
// CHECK: llvm.store
247253

248254
// CHECK: rocdl.ds_bpermute
249255
// CHECK: rocdl.ballot
250256
// CHECK: llvm.cond_br
251257
// CHECK: rocdl.global.load.lds
252258
// CHECK-NEXT: llvm.br
253-
// CHECK: _predicated_store
259+
// CHECK: llvm.cond_br
260+
// CHECK: llvm.store
254261

255262
// CHECK: rocdl.ds_bpermute
256263
// CHECK: rocdl.ballot
257264
// CHECK: llvm.cond_br
258265
// CHECK: rocdl.global.load.lds
259266
// CHECK-NEXT: llvm.br
260-
// CHECK: _predicated_store
267+
// CHECK: llvm.cond_br
268+
// CHECK: llvm.store
261269

262270
%2 = ttg.async_copy_global_to_local %1, %arg2 mask %67 other %cst_0 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
263271
tt.return

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
147147
// Note that mask/other alignment is 1 so we need 4 conditionals
148148

149149
// COMMON: rocdl.raw.ptr.buffer.load.lds
150-
// COMMON: _predicated_store
150+
// COMMON: llvm.cond_br
151+
// COMMON: llvm.store
151152

152153
// COMMON: rocdl.raw.ptr.buffer.load.lds
153-
// COMMON: _predicated_store
154+
// COMMON: llvm.cond_br
155+
// COMMON: llvm.store
154156

155157
// COMMON: rocdl.raw.ptr.buffer.load.lds
156-
// COMMON: _predicated_store
158+
// COMMON: llvm.cond_br
159+
// COMMON: llvm.store
157160

158161
// COMMON: rocdl.raw.ptr.buffer.load.lds
159-
// COMMON: _predicated_store
162+
// COMMON: llvm.cond_br
163+
// COMMON: llvm.store
160164

161165
// COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
162166
// COMMON-NOT: _predicated_store
167+
// COMMON-NOT: llvm.cond_br
168+
// COMMON-NOT: llvm.store
163169

164170
amdgpu.buffer_load_to_local %arg1[%arg2] mask=%67 other=%cst_0 into %arg3 : <f32>[tensor<32x32xi32, #blocked>] tensor<32x32xf32, #blocked> -> <32x32xf32, #shared, #smem, mutable>
165171
tt.return
@@ -257,22 +263,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
257263
// COMMON: rocdl.ds_bpermute
258264
// COMMON: rocdl.ballot
259265
// COMMON: rocdl.raw.ptr.buffer.load.lds
260-
// COMMON: _predicated_store
266+
// COMMON: llvm.cond_br
267+
// COMMON: llvm.store
261268

262269
// COMMON: rocdl.ds_bpermute
263270
// COMMON: rocdl.ballot
264271
// COMMON: rocdl.raw.ptr.buffer.load.lds
265-
// COMMON: _predicated_store
272+
// COMMON: llvm.cond_br
273+
// COMMON: llvm.store
266274

267275
// COMMON: rocdl.ds_bpermute
268276
// COMMON: rocdl.ballot
269277
// COMMON: rocdl.raw.ptr.buffer.load.lds
270-
// COMMON: _predicated_store
278+
// COMMON: llvm.cond_br
279+
// COMMON: llvm.store
271280

272281
// COMMON: rocdl.ds_bpermute
273282
// COMMON: rocdl.ballot
274283
// COMMON: rocdl.raw.ptr.buffer.load.lds
275-
// COMMON: _predicated_store
284+
// COMMON: llvm.cond_br
285+
// COMMON: llvm.store
276286

277287
// COMMON-NOT: rocdl.ds_bpermute
278288
// COMMON-NOT: rocdl.ballot

test/Proton/amd/protongpu_to_llvm.mlir

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ module attributes {"ttg.num-warps" = 8 : i32, ttg.profile_scratch_memory_alignme
8282
// CHECK-DAG: rocdl.workgroup.id.z
8383
// CHECK-DAG: rocdl.grid.dim.x
8484
// CHECK-DAG: rocdl.grid.dim.y
85-
// CHECK-DAG: %[[PID:.*]] = llvm.trunc %15 : i64 to i32
85+
// CHECK-DAG: %[[PID:.*]] = llvm.trunc %{{.*}} : i64 to i32
8686
// CHECK-DAG: %[[SIZE:.*]] = llvm.mlir.constant(384 : i32)
8787
// CHECK-DAG: %{{.*}} = llvm.mul %[[PID]], %[[SIZE]] : i32
8888
%1 = proton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32, offset = 0 : i32} : !tt.ptr<i32>
@@ -91,37 +91,24 @@ module attributes {"ttg.num-warps" = 8 : i32, ttg.profile_scratch_memory_alignme
9191
}
9292

9393
// -----
94-
9594
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
9695
#smem = #ttg.shared_memory
9796
module attributes {"ttg.num-warps" = 8 : i32, ttg.profile_scratch_memory_alignment = 128 : i32, ttg.profile_scratch_memory_size = 384 : i32} {
9897
// CHECK-LABEL: convert_smem_finalize
9998
// CHECK: llvm.inline_asm asm_dialect = att operand_attrs = [] "s_getreg_b32 $0, hwreg(HW_REG_XCC_ID, 0, 3)", "=s" : () -> i32
10099
// CHECK: llvm.inline_asm asm_dialect = att operand_attrs = [] "s_getreg_b32 $0, hwreg(HW_REG_HW_ID, 8, 4)", "=s" : () -> i32
101100
// CHECK: llvm.inline_asm asm_dialect = att operand_attrs = [] "s_getreg_b32 $0, hwreg(HW_REG_HW_ID, 13, 3)", "=s" : () -> i32
102-
// CONVERT-BUILTIN: llvm.cond_br %{{.*}}, ^bb1, ^bb9
103-
// CONVERT-BUILTIN: ^bb1: // pred: ^bb0
101+
// CONVERT-BUILTIN: llvm.cond_br %{{.*}}, ^bb1, ^bb3
102+
// CONVERT-BUILTIN: ^bb1:
104103
// CONVERT-BUILTIN: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<1>
105104
// CONVERT-BUILTIN: llvm.br ^bb2(%{{.*}} : i32)
106-
// CONVERT-BUILTIN: ^bb2(%{{.*}}: i32): // 2 preds: ^bb1, ^bb8
107-
// CONVERT-BUILTIN: llvm.cond_br %1, ^bb3, ^bb4
108-
// CONVERT-BUILTIN: bb3: // pred: ^bb2
109-
// CONVERT-BUILTIN: %{{.*}} = llvm.load %{{.*}} : !llvm.ptr<3> -> i32
110-
// CONVERT-BUILTIN: llvm.br ^bb5(%{{.*}} : i32)
111-
// CONVERT-BUILTIN: ^bb4: // pred: ^bb2
112-
// CONVERT-BUILTIN: llvm.br ^bb5(%{{.*}} : i32)
113-
// CONVERT-BUILTIN: ^bb5(%{{.*}}: i32): // 2 preds: ^bb3, ^bb4
105+
// CONVERT-BUILTIN: ^bb2(%{{.*}}: i32):
106+
// CONVERT-BUILTIN: llvm.load %{{.*}} : !llvm.ptr<3> -> i32
114107
// CONVERT-BUILTIN: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<1>
115-
// CONVERT-BUILTIN: llvm.cond_br %{{.*}}, ^bb6, ^bb7
116-
// CONVERT-BUILTIN: ^bb6: // pred: ^bb5
117-
// CONVERT-BUILTIN: %{{.*}} = llvm.load %{{.*}} : !llvm.ptr<3> -> i32
118-
// CONVERT-BUILTIN: llvm.br ^bb8(%{{.*}} : i32)
119-
// CONVERT-BUILTIN: ^bb7: // pred: ^bb5
120-
// CONVERT-BUILTIN: llvm.br ^bb8(%{{.*}} : i32)
121-
// CONVERT-BUILTIN: ^bb8(%{{.*}}: i32): // 2 preds: ^bb6, ^bb7
108+
// CONVERT-BUILTIN: llvm.load %{{.*}} : !llvm.ptr<3> -> i32
122109
// CONVERT-BUILTIN: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<1>
123-
// CONVERT-BUILTIN: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i32), ^bb9
124-
// CONVERT-BUILTIN: ^bb9: // 2 preds: ^bb0, ^bb8
110+
// CONVERT-BUILTIN: llvm.cond_br %{{.*}}, ^bb2(%{{.*}} : i32), ^bb3
111+
// CONVERT-BUILTIN: ^bb3:
125112
// CHECK: llvm.return
126113
llvm.func @convert_smem_finalize(%arg: !llvm.ptr<1>) attributes {noinline = false, nvvm.kernel = 1 : ui1} {
127114
%0 = ttg.local_alloc : () -> !ttg.memdesc<512xi32, #shared, #smem, mutable>

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,56 @@ def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> {
519519
}];
520520
}
521521

522+
//===----------------------------------------------------------------------===//
523+
// MaskedLoadOp
524+
//===----------------------------------------------------------------------===//
525+
def MaskedLoadOp : TT_AMDGPU_Op<"masked_load", []> {
526+
let summary = "Masked load operation";
527+
let description = [{
528+
Load operation with masking support. If the mask is true, loads from the given pointer. Works with LLVM types as a utility op for making LLVM conversion easier.
529+
}];
530+
let arguments = (ins
531+
LLVM_AnyPointer:$ptr,
532+
I1:$mask,
533+
LLVM_Type:$falseVal,
534+
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
535+
DefaultValuedAttr<BoolAttr, "false">:$forceNoAlias
536+
);
537+
538+
let results = (outs LLVM_Type:$result);
539+
540+
let assemblyFormat = [{
541+
$ptr `,` $mask `,` $falseVal
542+
oilist(`cacheModifier` `=` $cache)
543+
(`forceNoAlias` $forceNoAlias^)?
544+
attr-dict `:` functional-type(operands, results)
545+
}];
546+
}
547+
548+
//===----------------------------------------------------------------------===//
549+
// MaskedStoreOp
550+
//===----------------------------------------------------------------------===//
551+
def MaskedStoreOp : TT_AMDGPU_Op<"masked_store", []> {
552+
let summary = "Masked Store operation";
553+
let description = [{
554+
Store operation with masking support. If the mask is true, Store from the given pointer. Works with LLVM types as a utility op for making LLVM conversion easier.
555+
}];
556+
let arguments = (ins
557+
LLVM_AnyPointer:$ptr,
558+
LLVM_Type:$value,
559+
I1:$mask,
560+
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
561+
DefaultValuedAttr<BoolAttr, "false">:$forceNoAlias
562+
);
563+
564+
let assemblyFormat = [{
565+
$ptr `,` $value `,` $mask
566+
oilist(`cacheModifier` `=` $cache)
567+
(`forceNoAlias` $forceNoAlias^)?
568+
attr-dict `:` type(operands)
569+
}];
570+
}
571+
522572
//===----------------------------------------------------------------------===//
523573
// ScaledUpcastFp4Op
524574
//===----------------------------------------------------------------------===//
@@ -579,7 +629,6 @@ def ScaledUpcastFp8Op : TT_AMDGPU_Op<"scaled_upcast_fp8", [
579629
`:` type($input) `,` type($scale) `->` type($output)
580630
}];
581631
}
582-
583632
//===----------------------------------------------------------------------===//
584633
// InThreadTransposeOp
585634
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)