
Commit 3ba7d1e

Merge commit '05b500cc1e0ef668167174e47dd5ac88e909c245'
2 parents: 1d6ae76 + 05b500c


42 files changed: 427 additions, 149 deletions

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 6 additions & 3 deletions
@@ -1314,7 +1314,7 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [
 def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
   let summary = "gather multiple rows from a descriptor into a single tensor";
   let description = [{
-    The `tt.desciptor_gather` op will be lowered to NVIDIA TMA
+    The `tt.descriptor_gather` op will be lowered to NVIDIA TMA
     load operations on targets that support it.

     `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
@@ -1340,9 +1340,10 @@ def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<G
   let hasVerifier = 1;

   let extraClassDeclaration = [{
-    // TMA gathers have resstrictions on the minimum size of the gather result.
+    // TMA gathers have restrictions on the minimum size of the gather result.
     // This function verifies the result type.
-    static LogicalResult verifyResultType(Operation *op, mlir::ShapedType type);
+    static LogicalResult verifyResultType(Operation *op, ShapedType resultType,
+                                          RankedTensorType indicesType);
   }];
 }

@@ -1360,6 +1361,8 @@ def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [
     $desc `[` $x_offsets `,` $y_offset `]` `,` $src
     attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }

 def TT_ExperimentalTensormapCreateOp: TT_Op<

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ std::pair<OpResult, int64_t> getDefinitionAndDistance(scf::ForOp forOp,
 std::pair<Operation *, int64_t> getDefiningOpAndDistance(scf::ForOp forOp,
                                                          Value value);

-// Return maxumum length of the vectorized copy between registers and shared
+// Return maximum length of the vectorized copy between registers and shared
 // memory for the given tensor type and shared encoding.
 int getCopyVecBytes(RankedTensorType registerTy,
                     gpu::SharedEncodingTrait sharedEnc);

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -237,6 +237,11 @@ SetVector<Value> getNestedOperands(Operation *op);
 // Erase the given loop carried values from the loop, where `loop` is replaced
 // with a new loop.
 void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
+
+// Return true if two value sets may refer to the same allocation.
+bool mayAliasAllocations(const DenseSet<Value> &lhs,
+                         const DenseSet<Value> &rhs);
+
 } // namespace mlir

 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
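
Note: this header hunk adds only the declaration of `mayAliasAllocations`; the commit's actual implementation lives elsewhere in the changeset and is not shown here. A conservative stand-in with the same signature might look like the following sketch (hypothetical, not the commit's implementation), which flags a possible alias whenever the two sets share a value:

#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseSet.h"

// Hypothetical conservative sketch: two allocation sets may alias if they
// share any value. A real analysis would also look through view-like ops
// before comparing the underlying allocations.
static bool mayAliasAllocationsSketch(const llvm::DenseSet<mlir::Value> &lhs,
                                      const llvm::DenseSet<mlir::Value> &rhs) {
  for (mlir::Value v : lhs)
    if (rhs.contains(v))
      return true;
  return false;
}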

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 0 deletions
@@ -366,6 +366,8 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter", [DeclareOpInterfaceMet
     $desc_ptr `[` $x_offsets `,` $y_offset `]` $src
     attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }

 def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
@@ -171,7 +172,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
     if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType<IntegerAttr>(
             "ttg.total-num-warps"))
       numWarps = totalNumWarps.getInt();
-    newFuncOp->setAttr("nvvm.reqntid",
+    newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
                        rewriter.getDenseI32ArrayAttr(32 * numWarps));

     rewriter.eraseOp(funcOp);
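
Note: the second hunk swaps the hard-coded "nvvm.reqntid" string for the NVVM dialect's attribute-name getter, so the name stays in sync with the dialect definition. A minimal sketch of the resulting pattern (the helper function and its name are illustrative, not from the commit):

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"

// Pin the exact launch thread count on a lowered function: one Triton warp
// is 32 threads on NVIDIA targets, so `reqntid` is 32 * numWarps.
static void setReqntid(mlir::Operation *fn, mlir::OpBuilder &b, int numWarps) {
  fn->setAttr(mlir::NVVM::NVVMDialect::getReqntidAttrName(),
              b.getDenseI32ArrayAttr({32 * numWarps}));
}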

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ struct ReduceOpConversion
          resultIdx < resultDim; ++resultIdx) {
       auto smemIdx = resultIdx < op.getAxis() ? resultIdx : resultIdx + 1;
       if (resultShape[resultIdx] > smemShape[smemIdx]) {
-        // When srcShape smaller then src sizePerThread, only srcShape
+        // When srcShape smaller than src sizePerThread, only srcShape
         // elements is accumulated in smem. Modulo smemShape effectively
         // replicates srcShape elements to src sizePerThread.
         readIdx[smemIdx] =
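
Note: the reworded comment describes a wrap-around read: when the source shape is smaller than sizePerThread, shared memory holds only srcShape accumulated elements, and reads beyond that extent wrap via a modulo. A tiny standalone illustration of that index arithmetic (a sketch, not the lowering code itself):

#include <cstdint>

// Reads past the number of distinct accumulated elements wrap around,
// replicating the srcShape elements across the full sizePerThread range.
// Mirrors the modulo applied to readIdx[smemIdx] above.
static int64_t wrapSmemReadIndex(int64_t resultIdx, int64_t smemDim) {
  return resultIdx % smemDim;
}

For example, with smemDim = 4, result indices 0..7 read smem slots 0, 1, 2, 3, 0, 1, 2, 3.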

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 1 addition & 1 deletion
@@ -628,7 +628,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
 }
 // Proton patterns
 // NOTE: Because Proton's inputs are scalars and not tensors this conversion
-// isn't strictly nessessary however you could envision a case where we pass in
+// isn't strictly necessary however you could envision a case where we pass in
 // tensors in for Triton object specific tracing operations in which case we
 // would need to fill in the OpConversionPattern
 void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 50 additions & 35 deletions
@@ -1245,73 +1245,88 @@ LogicalResult GatherOp::inferReturnTypes(
 }

 // -- DescriptorGatherOp
-LogicalResult DescriptorGatherOp::verifyResultType(Operation *op,
-                                                   mlir::ShapedType type) {
-  if (type.getRank() != 2)
-    return op->emitOpError("result must be a 2D tensor, but got ") << type;
+LogicalResult
+DescriptorGatherOp::verifyResultType(Operation *op, ShapedType resultType,
+                                     RankedTensorType indicesType) {
+  if (indicesType.getRank() != 1)
+    return op->emitOpError("x offsets must be a 1D tensor, but got ")
+           << indicesType;
+  if (resultType.getRank() != 2)
+    return op->emitOpError("result must be a 2D tensor, but got ")
+           << resultType;

   // The swizzling of TMA accesses matches that of the MMAv3 shared memory
   // layouts. However, these have minimum size requirements.
   // TODO: We can support smaller gather sizes by padding the `local_alloc` this
   // lowers to to the nearest minimum tile size.
-  if (unsigned rows = type.getShape()[0]; rows < 8) {
+  if (unsigned rows = resultType.getShape()[0]; rows < 8) {
     return op->emitOpError("gather must have at least 8 rows, but got ")
            << rows;
   }

-  Type dtype = type.getElementType();
+  Type dtype = resultType.getElementType();
   if (dtype.getIntOrFloatBitWidth() > 32)
     return op->emitOpError("TMA dtype cannot be greater than 32 bits");

   unsigned minCols = 32 / dtype.getIntOrFloatBitWidth() * 8;
-  if (unsigned cols = type.getShape()[1]; cols < minCols) {
+  if (unsigned cols = resultType.getShape()[1]; cols < minCols) {
     return op->emitOpError("gather of ")
            << dtype << " must have at least " << minCols << " columns, but got "
            << cols;
   }

+  if (resultType.getShape()[0] != indicesType.getShape()[0]) {
+    return op->emitOpError("result tensor must have as many rows as indices (")
+           << indicesType.getShape()[0] << "), but got " << resultType;
+  }
+
   return success();
 }

-LogicalResult DescriptorGatherOp::verify() {
-  RankedTensorType blockType = getDesc().getType().getBlockType();
+static LogicalResult verifyGatherScatterOp(Operation *op,
+                                           RankedTensorType blockType,
+                                           RankedTensorType resultType,
+                                           RankedTensorType indicesType) {
   // Gather from `!tt.tensordesc<tensor<1xMxdtype>>`.
-  if (blockType.getRank() != 2)
-    return emitOpError("block must be a 2D tensor, but got ") << blockType;
-  if (blockType.getShape()[0] != 1)
-    return emitOpError("block must have exactly 1 row, but got ") << blockType;
-
-  // With x offsets `tensor<Nxinttype>`.
-  RankedTensorType indicesType = getXOffsets().getType();
-  if (indicesType.getRank() != 1)
-    return emitOpError("x offsets must be a 1D tensor, but got ")
-           << indicesType;
+  if (blockType.getRank() != 2) {
+    return op->emitOpError("block must be a 2D tensor, but got ") << blockType;
+  }
+  if (blockType.getShape()[0] != 1) {
+    return op->emitOpError("block must have exactly 1 row, but got ")
+           << blockType;
+  }

-  // Into `tensor<NxMxdtype>`.
-  RankedTensorType resultType = getType();
-  if (failed(verifyResultType(*this, resultType)))
+  // With x offsets `tensor<Nxinttype>` into `tensor<NxMxdtype>`.
+  if (failed(DescriptorGatherOp::verifyResultType(op, resultType, indicesType)))
     return failure();

-  if (resultType.getShape()[0] != indicesType.getShape()[0]) {
-    return emitOpError("result tensor must have as many rows as indices (")
-           << indicesType.getShape()[0] << "), but got " << resultType;
-  }
   if (resultType.getShape()[1] != blockType.getShape()[1]) {
-    return emitOpError("result tensor number of columns must match block (")
+    return op->emitOpError("result tensor number of columns must match block (")
            << blockType.getShape()[1] << "), but got " << resultType;
   }
   if (resultType.getElementType() != blockType.getElementType()) {
-    return emitOpError("result tensor element type must match block (")
+    return op->emitOpError("result tensor element type must match block (")
            << blockType.getElementType() << "), but got " << resultType;
   }

   return success();
 }

+LogicalResult DescriptorGatherOp::verify() {
+  return verifyGatherScatterOp(*this, getDesc().getType().getBlockType(),
+                               getResult().getType(), getXOffsets().getType());
+}
+
+// -- DescriptorScatterOp --
+LogicalResult DescriptorScatterOp::verify() {
+  return verifyGatherScatterOp(*this, getDesc().getType().getBlockType(),
+                               getSrc().getType(), getXOffsets().getType());
+}
+
 // -- DescriptorLoadOp --
-static LogicalResult verifyDesciptorLoadStoreType(Operation *op,
-                                                  TensorDescType desc,
-                                                  RankedTensorType tensor) {
+static LogicalResult verifyDescriptorLoadStoreType(Operation *op,
+                                                   TensorDescType desc,
+                                                   RankedTensorType tensor) {
   RankedTensorType block = desc.getBlockType();
   ArrayRef<int64_t> blockShape = block.getShape();
   ArrayRef<int64_t> tensorShape = tensor.getShape();
@@ -1328,17 +1343,17 @@ static LogicalResult verifyDesciptorLoadStoreType(Operation *op,
   if (blockShape == tensorShape &&
       block.getElementType() == tensor.getElementType())
     return success();
-  return op->emitOpError("tensor desciptor block and tensor types must match");
+  return op->emitOpError("tensor descriptor block and tensor types must match");
 }

 LogicalResult DescriptorLoadOp::verify() {
-  return verifyDesciptorLoadStoreType(*this, getDesc().getType(), getType());
+  return verifyDescriptorLoadStoreType(*this, getDesc().getType(), getType());
 }

 // -- DescriptorStoreOp --
 LogicalResult DescriptorStoreOp::verify() {
-  return verifyDesciptorLoadStoreType(*this, getDesc().getType(),
-                                      getSrc().getType());
+  return verifyDescriptorLoadStoreType(*this, getDesc().getType(),
+                                       getSrc().getType());
 }

 // -- ExperimentalTensormapCreateOp --
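
Note: the minimum-size checks in `verifyResultType` encode the TMA/MMAv3 tile constraints shown in the hunk above: at least 8 rows, and `32 / bitwidth * 8` columns, which works out to each gathered row spanning at least 32 bytes. A standalone worked sketch of that arithmetic (same formula as the verifier; the helper name is illustrative):

#include <cassert>

// Minimum column count for a TMA gather result, per element bit width.
// Each row must cover 32 bytes: 32-bit -> 8 cols (8 * 4 B), 16-bit -> 16
// cols (16 * 2 B), 8-bit -> 32 cols (32 * 1 B).
static unsigned minGatherCols(unsigned elemBitWidth) {
  return 32 / elemBitWidth * 8;
}

int main() {
  assert(minGatherCols(32) == 8);  // f32/i32
  assert(minGatherCols(16) == 16); // f16/bf16
  assert(minGatherCols(8) == 32);  // 8-bit types
  return 0;
}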

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -502,7 +502,7 @@ static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr,
                                         bool &value, StringRef desc) {
   auto boolAttr = mlir::dyn_cast<BoolAttr>(attr);
   if (!boolAttr) {
-    parser.emitError(parser.getNameLoc(), "expected an bool type in ") << desc;
+    parser.emitError(parser.getNameLoc(), "expected a bool type in ") << desc;
     return failure();
   }
   value = boolAttr.getValue();
@@ -2798,7 +2798,7 @@ std::string getSharedLayoutStr(RankedTensorType tensorType,
   // Shared layouts are a mapping of (block, offset) --> (...)

   // We can just use a single int to index into elementMapping because
-  // the 'swizzle' operation rearranges the indicies---and we want to keep it
+  // the 'swizzle' operation rearranges the indices---and we want to keep it
   // that way
   int32_t idx = 0;
   // Enumerate all the offsets for each block

lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp

Lines changed: 1 addition & 1 deletion
@@ -461,7 +461,7 @@ static scf::IfOp eraseIfResults(ImplicitLocOpBuilder &b, scf::IfOp ifOp,
 // epilogueK and the first iteration of bodyj(K+1). Hence the `- N` term in the
 // total number of iterations.
 //
-// What the above Python-psuedo-code glosses over is SSA dependency management.
+// What the above Python-pseudo-code glosses over is SSA dependency management.
 // To interpret the pseudocode as SSA IR, just imagine everything is put back
 // into allocas and SSA formation re-runs after fusion, which one should note
 // will introduce undefs.
