Commit 31e21cc

Merge OpenAI Triton commit 711caa4 (#3845)

This PR changes the Triton base from ff4f0bd to 711caa4 (Apr 4). Pass rate: 90.8%.

2 parents 1d6ae76 + 71e0eb7; commit 31e21cc

62 files changed, +1043 -472 lines

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 6 additions & 3 deletions
@@ -1314,7 +1314,7 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [
 def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
   let summary = "gather multiple rows from a descriptor into a single tensor";
   let description = [{
-    The `tt.desciptor_gather` op will be lowered to NVIDIA TMA
+    The `tt.descriptor_gather` op will be lowered to NVIDIA TMA
     load operations on targets that support it.

     `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
@@ -1340,9 +1340,10 @@ def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<G
   let hasVerifier = 1;

   let extraClassDeclaration = [{
-    // TMA gathers have resstrictions on the minimum size of the gather result.
+    // TMA gathers have restrictions on the minimum size of the gather result.
     // This function verifies the result type.
-    static LogicalResult verifyResultType(Operation *op, mlir::ShapedType type);
+    static LogicalResult verifyResultType(Operation *op, ShapedType resultType,
+                                          RankedTensorType indicesType);
   }];
 }

@@ -1360,6 +1361,8 @@ def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [
     $desc `[` $x_offsets `,` $y_offset `]` `,` $src
     attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }

 def TT_ExperimentalTensormapCreateOp: TT_Op<
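
The widened `verifyResultType` signature above suggests the verifier now cross-checks the gather result against the index tensor rather than inspecting the result type alone. A minimal sketch of what such a check could look like, assuming (purely for illustration) a rule that ties the result's leading dimension to the number of indices; the helper and constraint here are hypothetical stand-ins, not the commit's actual logic:

#include <cstdint>
#include <vector>

// Illustrative stand-in for an MLIR shaped type.
struct Shape {
  std::vector<int64_t> dims;
};

// Assumed rule for illustration only: gathering N row indices must produce a
// 2-D result whose leading dimension is N.
bool verifyGatherResult(const Shape &result, const Shape &indices) {
  if (indices.dims.size() != 1 || result.dims.size() != 2)
    return false;
  return result.dims[0] == indices.dims[0];
}

int main() {
  Shape indices{{32}};    // 32 gathered rows
  Shape ok{{32, 128}};    // leading dim matches the index count
  Shape bad{{16, 128}};   // mismatch: would be rejected
  return (verifyGatherResult(ok, indices) && !verifyGatherResult(bad, indices))
             ? 0
             : 1;
}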

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ template <> struct hash<CacheKey> {

 namespace mlir::triton::gpu {

+constexpr static char AttrMaxRegistersName[] = "ttg.maxnreg";
 constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
 constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
 constexpr static char AttrTargetName[] = "ttg.target";

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 3 additions & 1 deletion
@@ -376,7 +376,9 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
   let arguments = (ins
     Variadic<AnyType>:$explicitCaptures,
     DenseI32ArrayAttr:$partitionNumWarps,
-    OptionalAttr<DenseI32ArrayAttr>:$warpGroupStartIds
+    OptionalAttr<DenseI32ArrayAttr>:$warpGroupStartIds,
+    OptionalAttr<DenseI32ArrayAttr>:$requestedRegisters,
+    OptionalAttr<DenseI32ArrayAttr>:$actualRegisters
   );
   let results = (outs Variadic<AnyType>:$defaultPassthrough);

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ std::pair<OpResult, int64_t> getDefinitionAndDistance(scf::ForOp forOp,
 std::pair<Operation *, int64_t> getDefiningOpAndDistance(scf::ForOp forOp,
                                                          Value value);

-// Return maxumum length of the vectorized copy between registers and shared
+// Return maximum length of the vectorized copy between registers and shared
 // memory for the given tensor type and shared encoding.
 int getCopyVecBytes(RankedTensorType registerTy,
                     gpu::SharedEncodingTrait sharedEnc);

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 2 additions & 3 deletions
@@ -28,9 +28,6 @@ void lowerLoops(ModuleOp moduleOp);
 /// Pipeline the TMA stores in the loop.
 bool pipelineTMAStores(scf::ForOp forOp);

-/// Simple pipelining for the MMA ops which accumulator is modified in the loop.
-scf::ForOp pipelineMMAWithScaledAcc(scf::ForOp forOp);
-
 /// This does post-processing on the pipelined loop to try to pipeline wgmma
 /// ops.
 // TODO: this should be included as part of the pipeline but currently the wgmma
@@ -75,6 +72,8 @@ class CoarseSchedule {
   }

   bool isBefore(iterator a, iterator b) const {
+    if (a == b)
+      return false;
     for (auto it = begin(); it != end(); ++it) {
       if (it == a)
         return true;
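
A note on the `isBefore` change: without the new guard, the linear scan returns true as soon as it reaches `a`, so `isBefore(a, a)` would report true; the early exit makes the relation irreflexive, as a strict "comes before" check must be. A self-contained sketch of the fixed contract, using a plain vector in place of the schedule's internal order (the names here are illustrative, not from the commit):

#include <cassert>
#include <vector>

// Strict "comes before" check over an ordered list, mirroring the early exit
// added to CoarseSchedule::isBefore.
bool isBefore(const std::vector<int> &order, int a, int b) {
  if (a == b)
    return false; // never report an element as before itself
  for (int it : order) {
    if (it == a)
      return true; // reached a first: a comes before b
    if (it == b)
      return false; // reached b first: a does not come before b
  }
  return false;
}

int main() {
  std::vector<int> order = {3, 1, 2};
  assert(isBefore(order, 3, 2));
  assert(!isBefore(order, 2, 3));
  assert(!isBefore(order, 1, 1)); // the case the new guard fixes
}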

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -237,6 +237,11 @@ SetVector<Value> getNestedOperands(Operation *op);
 // Erase the given loop carried values from the loop, where `loop` is replaced
 // with a new loop.
 void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
+
+// Return true if two value sets may refer to the same allocation.
+bool mayAliasAllocations(const DenseSet<Value> &lhs,
+                         const DenseSet<Value> &rhs);
+
 } // namespace mlir

 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 0 deletions
@@ -366,6 +366,8 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter", [DeclareOpInterfaceMet
     $desc_ptr `[` $x_offsets `,` $y_offset `]` $src
     attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }

 def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {

lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp

Lines changed: 89 additions & 0 deletions
@@ -18,6 +18,20 @@ struct AllocateWarpGroups
   void runOnOperation() override {
     ModuleOp mod = getOperation();

+    int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
+
+    struct WarpGroupInfo {
+      SmallVector<Region *> partitions;
+      int maxRequestedRegs = 0;
+      unsigned numWarps = 0;
+    };
+    struct WarpGroupPartition {
+      int startId;
+      Region *partition;
+      int32_t estRegs;
+      int numWarps;
+    };
+
     // Compute the total number of warps required at any given time.
     int baseNumWarps = lookupNumWarps(mod);
     int maxExtraWarps = 0;
@@ -42,6 +56,81 @@ struct AllocateWarpGroups
         startId += size;
       }
       op.setWarpGroupStartIds(startIds);
+
+      // Require that an estimate has been set and that we have even warpgroups.
+      auto regsAttr = op.getRequestedRegisters();
+      if (!regsAttr || op.getTotalPartitionWarps() % 4 != 0)
+        return;
+
+      // Group the partitions into warpgroups.
+      SmallVector<WarpGroupPartition> orderedPartitions;
+      for (auto [startId, partition, estRegs, numWarps] :
+           llvm::zip(startIds, op.getPartitionRegions(), *regsAttr, arr))
+        orderedPartitions.push_back({startId, partition, estRegs, numWarps});
+      llvm::sort(orderedPartitions,
+                 [&](auto lhs, auto rhs) { return lhs.startId < rhs.startId; });
+
+      // Iterate over the partitions and assign them to warp groups. Determine
+      // the maximum number of requested registers per warp group.
+      SmallVector<WarpGroupInfo> warpGroups;
+      for (auto [startId, partition, estRegs, numWarps] : orderedPartitions) {
+        if (startId % 4 == 0) {
+          warpGroups.push_back(WarpGroupInfo{});
+        }
+        warpGroups.back().partitions.push_back(partition);
+        // Round up to the nearest multiple of 8.
+        int estRegsCeil8 = llvm::divideCeil(estRegs, 8) * 8;
+        warpGroups.back().maxRequestedRegs =
+            std::max<int>(warpGroups.back().maxRequestedRegs, estRegsCeil8);
+        warpGroups.back().numWarps += numWarps;
+      }
+
+      // Determine the maximum number of registers per thread. This may have
+      // been set by the user.
+      int maxnreg;
+      if (auto maxnregAttr =
+              op->getAttrOfType<IntegerAttr>(AttrMaxRegistersName)) {
+        maxnreg = maxnregAttr.getInt();
+      } else {
+        maxnreg = (1 << 16) / (baseNumWarps + op.getTotalPartitionWarps()) /
                  threadsPerWarp;
+        maxnreg = maxnreg / 8 * 8;
+      }
+
+      // Compute the register deficit over the partition warp groups.
+      int registerDeficit = 0;
+      for (const WarpGroupInfo &wg : warpGroups) {
+        assert(wg.numWarps % 4 == 0);
+        registerDeficit +=
+            (maxnreg - wg.maxRequestedRegs) * wg.numWarps * threadsPerWarp;
+      }
+      if (registerDeficit <= 0)
+        return;
+
+      // Determine the number of extra registers that we can distribute to the
+      // default warp group.
+      int leftover =
+          ((baseNumWarps * threadsPerWarp * maxnreg) + registerDeficit) /
+          baseNumWarps / threadsPerWarp;
+      // Round down to the nearest multiple of 8.
+      leftover = leftover / 8 * 8;
+
+      // Generate setmaxnreg in each partition according to its warp group.
+      SmallVector<int32_t> maxnregsPerPartition(1 + arr.size());
+      for (const WarpGroupInfo &wg : warpGroups) {
+        for (Region *region : wg.partitions) {
+          maxnregsPerPartition[1 + region->getRegionNumber()] =
+              wg.maxRequestedRegs;
+        }
+      }
+      // Set the register usage for the default warp group.
+      maxnregsPerPartition.front() = leftover;
+      op.setActualRegisters(maxnregsPerPartition);
+
+      // Set the initial max number of registers. This is needed for PTXAS to
+      // cooperate.
+      mod->setAttr(AttrMaxRegistersName,
+                   Builder(op.getContext()).getI32IntegerAttr(maxnreg));
     });

     Builder b(&getContext());
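
To make the register-balancing arithmetic above concrete, here is a hedged worked example. The inputs (4 default warps, two partition warpgroups totaling 8 warps, each estimating 48 registers per thread, 32 threads per warp) are invented for illustration; only the formulas mirror the pass.

#include <cstdio>

int main() {
  const int threadsPerWarp = 32;
  const int baseNumWarps = 4;   // default warp group
  const int partitionWarps = 8; // two partition warpgroups of 4 warps
  const int requestedRegs = 48; // per-thread estimate for each partition

  // Default per-thread cap: 64K registers split evenly over all warps,
  // rounded down to a multiple of 8: 65536 / 12 / 32 = 170 -> 168.
  int maxnreg = (1 << 16) / (baseNumWarps + partitionWarps) / threadsPerWarp;
  maxnreg = maxnreg / 8 * 8;

  // Each partition warpgroup is capped at its estimate (rounded up to a
  // multiple of 8), freeing registers relative to the even split:
  // (168 - 48) * 8 * 32 = 30720.
  int maxRequestedRegs = (requestedRegs + 7) / 8 * 8;
  int registerDeficit =
      (maxnreg - maxRequestedRegs) * partitionWarps * threadsPerWarp;

  // The freed registers are redistributed to the default warp group:
  // (4 * 32 * 168 + 30720) / 4 / 32 = 408, already a multiple of 8.
  int leftover = (baseNumWarps * threadsPerWarp * maxnreg + registerDeficit) /
                 baseNumWarps / threadsPerWarp;
  leftover = leftover / 8 * 8;

  std::printf("maxnreg=%d deficit=%d default-group=%d\n", maxnreg,
              registerDeficit, leftover);
  return 0;
}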

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 16 additions & 7 deletions
@@ -1,3 +1,4 @@
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
@@ -124,11 +125,11 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
           mlir::IntegerType::get(llvmFuncOp.getContext(), 8);
       const auto arrayType = mlir::LLVM::LLVMArrayType::get(
           llvmFuncOp.getContext(), byteType, 128);
-      llvmFuncOp.setArgAttr(i, "llvm.byval",
+      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getByValAttrName(),
                             mlir::TypeAttr::get(arrayType));
-      llvmFuncOp.setArgAttr(i, "nvvm.grid_constant",
+      llvmFuncOp.setArgAttr(i, NVVM::NVVMDialect::getGridConstantAttrName(),
                             mlir::UnitAttr::get(llvmFuncOp.getContext()));
-      llvmFuncOp.setArgAttr(i, "llvm.align",
+      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getAlignAttrName(),
                             mlir::IntegerAttr::get(i32_type, 64));
     }
   }
@@ -154,7 +155,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {

     if (LLVM::isKernel(funcOp)) {
       // Set an attribute to indicate this function is a kernel entry.
-      newFuncOp->setAttr("nvvm.kernel",
+      newFuncOp->setAttr(NVVM::NVVMDialect::getKernelFuncAttrName(),
                          rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
       newFuncOp.setLinkage(LLVM::Linkage::External);
     } else {
@@ -165,13 +166,21 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
         ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
       newFuncOp.setLinkage(LLVM::Linkage::Internal);
     }
-    // Set an attribute for reqntidx, it could be used in latter LLVM codegen
-    // for `nvvm.annotation` metadata.
+
+    // Determine the actual number of required warps.
     int numWarps = triton::gpu::lookupNumWarps(funcOp);
     if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType<IntegerAttr>(
             "ttg.total-num-warps"))
       numWarps = totalNumWarps.getInt();
-    newFuncOp->setAttr("nvvm.reqntid",
+
+    // Set `nvvm.maxnreg` if it was specified on the module.
+    if (Attribute maxnregAttr =
+            funcOp.getParentOp()->getAttr(triton::gpu::AttrMaxRegistersName))
+      newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);
+
+    // Set an attribute for reqntidx; it can be used in later LLVM codegen
+    // for `nvvm.annotation` metadata.
+    newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
                        rewriter.getDenseI32ArrayAttr(32 * numWarps));

     rewriter.eraseOp(funcOp);
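
A small aside on the attribute-name cleanup in this file: fetching names like `nvvm.kernel` through dialect-provided getters instead of hard-coding strings means a rename or typo fails at compile time rather than producing a silently ignored attribute. A toy sketch of the pattern (the struct and names below are stand-ins, not MLIR's API):

#include <cassert>
#include <string>

// Stand-in for a dialect that owns its attribute names.
struct ToyDialect {
  static constexpr const char *getKernelAttrName() { return "toy.kernel"; }
};

int main() {
  // A typo in a raw literal like "toy.kernl" would only surface at runtime;
  // a typo in getKernelAttrName() fails to compile.
  assert(std::string(ToyDialect::getKernelAttrName()) == "toy.kernel");
}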

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ struct ReduceOpConversion
          resultIdx < resultDim; ++resultIdx) {
       auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1;
       if (resultShape[resultIdx] > smemShape[smemIdx]) {
-        // When srcShape smaller then src sizePerThread, only srcShape
+        // When srcShape smaller than src sizePerThread, only srcShape
         // elements is accumulated in smem. Modulo smemShape effectively
         // replicates srcShape elements to src sizePerThread.
         readIdx[smemIdx] =
