
Commit 05b500c

[TritonGPU] Misc cleanups (#6402)
These are some diffs split from other patches that didn't quite make it.

* Adds verifiers for TMA scatter ops
* Some misc code cleanup
1 parent 1fe4aa5 commit 05b500c

File tree

6 files changed: +62 −33 lines changed

* include/triton/Dialect/Triton/IR/TritonOps.td
* include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
* lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
* lib/Dialect/Triton/IR/Ops.cpp
* lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
* test/Analysis/test-membar.mlir


include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 5 additions & 2 deletions
@@ -1340,9 +1340,10 @@ def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<G
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    // TMA gathers have resstrictions on the minimum size of the gather result.
+    // TMA gathers have restrictions on the minimum size of the gather result.
     // This function verifies the result type.
-    static LogicalResult verifyResultType(Operation *op, mlir::ShapedType type);
+    static LogicalResult verifyResultType(Operation *op, ShapedType resultType,
+                                          RankedTensorType indicesType);
   }];
 }
 
@@ -1360,6 +1361,8 @@ def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [
     $desc `[` $x_offsets `,` $y_offset `]` `,` $src
     attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }
 
 def TT_ExperimentalTensormapCreateOp: TT_Op<
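Note: `let hasVerifier = 1;` is the standard MLIR ODS switch that declares a `verify()` hook on the generated op class; the dialect's Ops.cpp must then define it, and MLIR runs it whenever the op is built or parsed. The definition this patch pairs with it (shown in the lib/Dialect/Triton/IR/Ops.cpp diff below) is:

LogicalResult DescriptorScatterOp::verify() {
  // Reuse the shared gather/scatter shape checks: the scattered source is
  // validated against the descriptor block and the x-offset indices.
  return verifyGatherScatterOp(*this, getDesc().getType().getBlockType(),
                               getSrc().getType(), getXOffsets().getType());
}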

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 0 deletions
@@ -366,6 +366,8 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter", [DeclareOpInterfaceMet
     $desc_ptr `[` $x_offsets `,` $y_offset `]` $src
     attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }
 
 def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
@@ -171,7 +172,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
     if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType<IntegerAttr>(
             "ttg.total-num-warps"))
       numWarps = totalNumWarps.getInt();
-    newFuncOp->setAttr("nvvm.reqntid",
+    newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
                        rewriter.getDenseI32ArrayAttr(32 * numWarps));
 
     rewriter.eraseOp(funcOp);
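The change above replaces the hard-coded "nvvm.reqntid" string with the NVVM dialect's accessor for that attribute name. A minimal sketch of the pattern, with a hypothetical helper name (only `getReqntidAttrName()` and `getDenseI32ArrayAttr()` are real MLIR APIs, as used in the diff):

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: pin the exact CTA thread count required by a kernel.
// Going through the dialect's getter keeps the attribute key in sync with
// NVVM instead of duplicating the "nvvm.reqntid" string at every call site.
static void setRequiredThreads(Operation *kernel, OpBuilder &b, int numWarps) {
  kernel->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
                  b.getDenseI32ArrayAttr(32 * numWarps)); // 32 threads per warp
}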

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 43 additions & 28 deletions
@@ -1245,69 +1245,84 @@ LogicalResult GatherOp::inferReturnTypes(
 }
 
 // -- DescriptorGatherOp
-LogicalResult DescriptorGatherOp::verifyResultType(Operation *op,
-                                                   mlir::ShapedType type) {
-  if (type.getRank() != 2)
-    return op->emitOpError("result must be a 2D tensor, but got ") << type;
+LogicalResult
+DescriptorGatherOp::verifyResultType(Operation *op, ShapedType resultType,
+                                     RankedTensorType indicesType) {
+  if (indicesType.getRank() != 1)
+    return op->emitOpError("x offsets must be a 1D tensor, but got ")
+           << indicesType;
+  if (resultType.getRank() != 2)
+    return op->emitOpError("result must be a 2D tensor, but got ")
+           << resultType;
 
   // The swizzling of TMA accesses matches that of the MMAv3 shared memory
   // layouts. However, these have minimum size requirements.
   // TODO: We can support smaller gather sizes by padding the `local_alloc` this
   // lowers to to the nearest minimum tile size.
-  if (unsigned rows = type.getShape()[0]; rows < 8) {
+  if (unsigned rows = resultType.getShape()[0]; rows < 8) {
     return op->emitOpError("gather must have at least 8 rows, but got ")
            << rows;
   }
 
-  Type dtype = type.getElementType();
+  Type dtype = resultType.getElementType();
   if (dtype.getIntOrFloatBitWidth() > 32)
     return op->emitOpError("TMA dtype cannot be greater than 32 bits");
 
   unsigned minCols = 32 / dtype.getIntOrFloatBitWidth() * 8;
-  if (unsigned cols = type.getShape()[1]; cols < minCols) {
+  if (unsigned cols = resultType.getShape()[1]; cols < minCols) {
     return op->emitOpError("gather of ")
            << dtype << " must have at least " << minCols << " columns, but got "
            << cols;
   }
 
+  if (resultType.getShape()[0] != indicesType.getShape()[0]) {
+    return op->emitOpError("result tensor must have as many rows as indices (")
+           << indicesType.getShape()[0] << "), but got " << resultType;
+  }
+
   return success();
 }
 
-LogicalResult DescriptorGatherOp::verify() {
-  RankedTensorType blockType = getDesc().getType().getBlockType();
+static LogicalResult verifyGatherScatterOp(Operation *op,
+                                           RankedTensorType blockType,
+                                           RankedTensorType resultType,
+                                           RankedTensorType indicesType) {
   // Gather from `!tt.tensordesc<tensor<1xMxdtype>>`.
-  if (blockType.getRank() != 2)
-    return emitOpError("block must be a 2D tensor, but got ") << blockType;
-  if (blockType.getShape()[0] != 1)
-    return emitOpError("block must have exactly 1 row, but got ") << blockType;
-
-  // With x offsets `tensor<Nxinttype>`.
-  RankedTensorType indicesType = getXOffsets().getType();
-  if (indicesType.getRank() != 1)
-    return emitOpError("x offsets must be a 1D tensor, but got ")
-           << indicesType;
+  if (blockType.getRank() != 2) {
+    return op->emitOpError("block must be a 2D tensor, but got ") << blockType;
+  }
+  if (blockType.getShape()[0] != 1) {
+    return op->emitOpError("block must have exactly 1 row, but got ")
+           << blockType;
+  }
 
-  // Into `tensor<NxMxdtype>`.
-  RankedTensorType resultType = getType();
-  if (failed(verifyResultType(*this, resultType)))
+  // With x offsets `tensor<Nxinttype>` into `tensor<NxMxdtype>`.
+  if (failed(DescriptorGatherOp::verifyResultType(op, resultType, indicesType)))
     return failure();
 
-  if (resultType.getShape()[0] != indicesType.getShape()[0]) {
-    return emitOpError("result tensor must have as many rows as indices (")
-           << indicesType.getShape()[0] << "), but got " << resultType;
-  }
   if (resultType.getShape()[1] != blockType.getShape()[1]) {
-    return emitOpError("result tensor number of columns must match block (")
+    return op->emitOpError("result tensor number of columns must match block (")
            << blockType.getShape()[1] << "), but got " << resultType;
   }
   if (resultType.getElementType() != blockType.getElementType()) {
-    return emitOpError("result tensor element type must match block (")
+    return op->emitOpError("result tensor element type must match block (")
           << blockType.getElementType() << "), but got " << resultType;
   }
 
   return success();
 }
 
+LogicalResult DescriptorGatherOp::verify() {
+  return verifyGatherScatterOp(*this, getDesc().getType().getBlockType(),
+                               getResult().getType(), getXOffsets().getType());
+}
+
+// -- DescriptorScatterOp --
+LogicalResult DescriptorScatterOp::verify() {
+  return verifyGatherScatterOp(*this, getDesc().getType().getBlockType(),
+                               getSrc().getType(), getXOffsets().getType());
+}
+
 // -- DescriptorLoadOp --
 static LogicalResult verifyDescriptorLoadStoreType(Operation *op,
                                                    TensorDescType desc,
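To make the size restrictions in `verifyResultType` concrete, here is a small standalone sketch of the column rule (the `32 / bitwidth * 8` constant comes straight from the diff; the function name is illustrative):

#include <cstdio>

// Minimum gather width: 32 / bitwidth * 8 elements, which works out to each
// row spanning at least 32 bytes, matching the MMAv3 swizzle tile minimum.
static unsigned minGatherCols(unsigned elemBitWidth) {
  return 32 / elemBitWidth * 8;
}

int main() {
  // f32 -> 8 columns, f16 -> 16 columns, i8 -> 32 columns; rows must be >= 8.
  printf("f32: %u, f16: %u, i8: %u\n",
         minGatherCols(32), minGatherCols(16), minGatherCols(8));
  return 0;
}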

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 7 additions & 1 deletion
@@ -244,7 +244,8 @@ LogicalResult AsyncTMAGatherOp::verify() {
   triton::gpu::MemDescType resultType = getResult().getType();
   if (!resultType.getMutableMemory())
     return emitOpError("cannot store into immutable memory");
-  return DescriptorGatherOp::verifyResultType(*this, resultType);
+  return DescriptorGatherOp::verifyResultType(*this, resultType,
+                                              getXOffsets().getType());
 }
 
 void AsyncTMAGatherOp::getEffects(
@@ -259,6 +260,11 @@ void AsyncTMAGatherOp::getEffects(
 }
 
 // -- AsyncTMAScatter --
+LogicalResult AsyncTMAScatterOp::verify() {
+  return DescriptorGatherOp::verifyResultType(*this, getSrc().getType(),
+                                              getXOffsets().getType());
+}
+
 void AsyncTMAScatterOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {

test/Analysis/test-membar.mlir

Lines changed: 3 additions & 1 deletion
@@ -787,12 +787,14 @@ tt.func @tma_special_cases(%arg1: !tt.ptr<i8, 0>) -> (tensor<256x64xf16, #blocke
   ttng.async_tma_copy_global_to_local %arg1[%c0, %c0] %alloc, %barrier, %true : !tt.ptr<i8, 0>, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
   ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>
 
+  // CHECK-NEXT: memdesc_subview
   // CHECK-NEXT: ttng.barrier_expect
   // CHECK-NEXT: ttng.async_tma_gather
   // CHECK-NEXT: gpu.barrier
   // CHECK-NEXT: ttng.wait_barrier
+  %view = ttg.memdesc_subview %alloc[%c0, %c0] : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory, mutable>
   ttng.barrier_expect %barrier, 49152, %true : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>
-  ttng.async_tma_gather %arg1[%cx, %c0] %alloc, %barrier, %true : !tt.ptr<i8, 0>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, i1
+  ttng.async_tma_gather %arg1[%cx, %c0] %view, %barrier, %true : !tt.ptr<i8, 0>, tensor<32xi32>, i32, !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<32x64xf16, #shared, #ttg.shared_memory, mutable>, i1
   ttng.wait_barrier %barrier, %c0 : !ttg.memdesc<1xi64, #shared1, #ttg.shared_memory, mutable>
 
   // CHECK-NEXT: gpu.barrier
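The test change follows from the new row check: this gather supplies 32 x-offsets (`tensor<32xi32>`), so its destination must have exactly 32 rows, hence the `memdesc_subview` carved out of the 256x64 allocation. A trivial standalone sketch of that constraint (function name illustrative):

#include <cassert>

// Row rule from the verifier: the gather destination needs one row per index.
static bool rowsMatchIndices(long resultRows, long numIndices) {
  return resultRows == numIndices;
}

int main() {
  assert(!rowsMatchIndices(256, 32)); // old form: whole 256x64 alloc -> rejected
  assert(rowsMatchIndices(32, 32));   // new form: 32x64 subview -> accepted
  return 0;
}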
