Skip to content

Commit e322605

Browse files
authored
[TensorDesc] Cleanup ttng IR representation (#7036)
This moves the remaining `triton::ExperimentalFoo` ops to `ttng::Foo`, changes the `ttng::AsyncTMA` ops to take `tensordesc` arguments, and also removes `TensorDescToTMAPtrOp`.
1 parent 2a10b48 commit e322605

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+389
-660
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ dev-install-llvm:
106106

107107
.PHONY: golden-samples
108108
golden-samples: triton-opt
109-
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-pipeline -canonicalize | \
109+
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-pipeline -canonicalize | \
110110
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
111111
-o test/TritonGPU/samples/simulated-grouped-gemm.mlir
112-
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-pipeline -canonicalize | \
112+
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-pipeline -canonicalize | \
113113
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="\bmodule" \
114114
-o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,22 +1033,6 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
10331033
}];
10341034
}
10351035

1036-
def ReinterpretTensorDescOp : TT_Op<"reinterpret_tensor_descriptor", [Pure]> {
1037-
let summary = "Reinterpret a pointer as a tensor descriptor";
1038-
1039-
let description = [{
1040-
This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
1041-
Ideally, we can remove this once the APIs are fully fleshed out.
1042-
}];
1043-
1044-
let arguments = (ins TT_Ptr:$rawDesc);
1045-
let results = (outs TT_TensorDescType:$result);
1046-
1047-
let assemblyFormat = [{
1048-
$rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result))
1049-
}];
1050-
}
1051-
10521036
// The following ops, including `call`, `func`, and `return` are copied and modified from
10531037
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
10541038
// We could revert it back once MLIR has a better inliner interface.
@@ -1390,54 +1374,5 @@ def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [TT_DescriptorStoreLike
13901374
let hasVerifier = 1;
13911375
}
13921376

1393-
def TT_ExperimentalTensormapCreateOp: TT_Op<
1394-
"experimental_tensormap_create",
1395-
[
1396-
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
1397-
AttrSizedOperandSegments,
1398-
]
1399-
> {
1400-
let summary = "Create a new TMA descriptor on device";
1401-
let arguments = (
1402-
ins
1403-
TT_PtrType:$desc_ptr,
1404-
TT_PtrType:$global_address,
1405-
Variadic<I32>:$box_dim,
1406-
Variadic<I32>:$global_dim,
1407-
Variadic<I64>:$global_stride,
1408-
Variadic<I32>:$element_stride,
1409-
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<15>]>:$elem_type,
1410-
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<2>]>:$interleave_layout,
1411-
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$swizzle_mode,
1412-
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$fill_mode
1413-
);
1414-
let extraClassDeclaration = [{
1415-
int32_t getRank() {
1416-
return getBoxDim().size();
1417-
}
1418-
}];
1419-
let assemblyFormat = [{
1420-
$desc_ptr `,` $global_address `,`
1421-
`[` $box_dim `]` `,`
1422-
`[` $global_dim `]` `,`
1423-
`[` $global_stride `]` `,`
1424-
`[` $element_stride `]`
1425-
attr-dict `:` functional-type(operands, results)
1426-
}];
1427-
1428-
let hasVerifier = 1;
1429-
}
1430-
1431-
def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op<
1432-
"experimental_tensormap_fenceproxy_acquire",
1433-
[MemoryEffects<[MemWrite<GlobalMemory>]>]
1434-
> {
1435-
let summary = "Acquire fence on a tensormap object";
1436-
let arguments = (ins TT_PtrType:$desc_ptr);
1437-
let assemblyFormat = [{
1438-
$desc_ptr attr-dict `:` qualified(type($desc_ptr))
1439-
}];
1440-
}
1441-
14421377

14431378
#endif // Triton_OPS

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def TT_TensorPtr : TT_PtrOf<[TT_Tensor]>;
9292
// Any Type in Triton IR
9393
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike, TT_TensorPtr]>;
9494

95-
// Result type of ExperimentalMakeTensorDescriptor
95+
// Result type of MakeTensorDescriptor
9696
def TT_TensorDescType : TritonTypeDef<"TensorDesc", "tensordesc", []> {
9797
let summary = "Tensor descriptor type (`::mlir::triton::TensorDescType`) in Triton IR type system";
9898

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 80 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -262,26 +262,6 @@ def TTNG_ArriveBarrierOp : TTNG_Op<"arrive_barrier"> {
262262
let hasVerifier = 1;
263263
}
264264

265-
def TTNG_TensorDescToTMAPtrOp : TTNG_Op<"tensor_desc_to_tma_ptr", [Pure]> {
266-
let summary = "Convert tensor descriptor to pointer to tma descriptor";
267-
268-
let arguments = (ins TT_TensorDescType:$desc);
269-
let results = (outs TT_Ptr:$ptr);
270-
271-
let assemblyFormat = [{
272-
$desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr))
273-
}];
274-
275-
let builders = [
276-
OpBuilder<(ins "Value":$desc), [{
277-
auto ptrTy = triton::PointerType::get($_builder.getI8Type(), 1);
278-
build($_builder, $_state, ptrTy, desc);
279-
}]>
280-
];
281-
282-
let hasCanonicalizeMethod = 1;
283-
}
284-
285265

286266
def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local"> {
287267
let summary = "copy data based on descriptor from global memory to local memory asynchronously";
@@ -291,12 +271,12 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local">
291271
asynchronously. This is analogue to tt.load except the data are copied to
292272
local memory pointed by the memory descriptor instead of a distributed
293273
tensor. The data copied depends on the global memory descriptor pointed to
294-
by `desc_ptr`.
274+
by `desc`.
295275
}];
296276

297277
let hasVerifier = 1;
298278
let arguments = (ins
299-
Arg<TT_PtrType, "", [MemRead<GlobalMemory>]>:$desc_ptr,
279+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
300280
Variadic<I32>:$coord,
301281
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
302282
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$result,
@@ -307,9 +287,9 @@ def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local">
307287
);
308288

309289
let assemblyFormat = [{
310-
$desc_ptr `[` $coord `]` $result `,` $barrier `,` $pred
290+
$desc `[` $coord `]` $result `,` $barrier `,` $pred
311291
oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)
312-
attr-dict `:` qualified(type($desc_ptr)) `,` qualified(type($barrier)) `->` qualified(type($result))
292+
attr-dict `:` qualified(type($desc)) `,` qualified(type($barrier)) `->` qualified(type($result))
313293
}];
314294
}
315295

@@ -321,18 +301,18 @@ def TTNG_AsyncTMACopyLocalToGlobalOp : TTNG_Op<"async_tma_copy_local_to_global">
321301
asynchronously. This is analogue to tt.store except the data are copied from
322302
local memory pointed by the memory descriptor instead of a distributed
323303
tensor. The data copied depends on the global memory descriptor pointed to
324-
by `desc_ptr`.
304+
by `desc`.
325305
}];
326306

327307
let arguments = (ins
328-
Arg<TT_PtrType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc_ptr,
308+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc,
329309
Variadic<I32>:$coord,
330310
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
331311
);
332312

333313
let assemblyFormat = [{
334-
$desc_ptr `[` $coord `]` $src
335-
attr-dict `:` qualified(type($desc_ptr)) `,` qualified(type($src))
314+
$desc `[` $coord `]` $src
315+
attr-dict `:` qualified(type($desc)) `,` qualified(type($src))
336316
}];
337317
}
338318

@@ -348,14 +328,14 @@ def TTNG_AsyncTMAReduceOp : TTNG_Op<"async_tma_reduce", [MemoryEffects<[MemRead<
348328

349329
let arguments = (ins
350330
TT_DescriptorReduceKindAttr:$kind,
351-
Arg<TT_PtrType, "", [MemRead<GlobalMemory>]>:$desc_ptr,
331+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
352332
Variadic<I32>:$coord,
353333
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
354334
);
355335

356336
let assemblyFormat = [{
357-
$kind `,` $desc_ptr `[` $coord `]` $src
358-
attr-dict `:` qualified(type($desc_ptr)) `,` qualified(type($src))
337+
$kind `,` $desc `[` $coord `]` $src
338+
attr-dict `:` qualified(type($desc)) `,` qualified(type($src))
359339
}];
360340
}
361341

@@ -369,7 +349,7 @@ def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather"> {
369349
}];
370350

371351
let arguments = (ins
372-
Arg<TT_PtrType, "", [MemRead<GlobalMemory>]>:$desc_ptr,
352+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
373353
RankedTensorOf<[I32]>:$x_offsets,
374354
I32:$y_offset,
375355
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
@@ -378,7 +358,7 @@ def TTNG_AsyncTMAGatherOp : TTNG_Op<"async_tma_gather"> {
378358
);
379359

380360
let assemblyFormat = [{
381-
$desc_ptr `[` $x_offsets `,` $y_offset `]` $result `,` $barrier `,` $pred
361+
$desc `[` $x_offsets `,` $y_offset `]` $result `,` $barrier `,` $pred
382362
attr-dict `:` type(operands)
383363
}];
384364

@@ -397,14 +377,14 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter"> {
397377
}];
398378

399379
let arguments = (ins
400-
Arg<TT_PtrType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc_ptr,
380+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc,
401381
RankedTensorOf<[I32]>:$x_offsets,
402382
I32:$y_offset,
403383
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
404384
);
405385

406386
let assemblyFormat = [{
407-
$desc_ptr `[` $x_offsets `,` $y_offset `]` $src
387+
$desc `[` $x_offsets `,` $y_offset `]` $src
408388
attr-dict `:` type(operands)
409389
}];
410390

@@ -700,4 +680,69 @@ def TTNG_TMEMCopyOp : TTNG_Op<"tmem_copy"> {
700680
let hasVerifier = 1;
701681
}
702682

683+
def TTNG_ReinterpretTensorDescOp : TTNG_Op<"reinterpret_tensor_descriptor", [Pure]> {
684+
let summary = "Reinterpret a pointer as a tensor descriptor";
685+
686+
let description = [{
687+
This Op exists to help the transition from untyped raw TMA objects to typed Tensor descriptor objects.
688+
Ideally, we can remove this once the APIs are fully fleshed out.
689+
}];
690+
691+
let arguments = (ins TT_Ptr:$rawDesc);
692+
let results = (outs TT_TensorDescType:$result);
693+
694+
let assemblyFormat = [{
695+
$rawDesc attr-dict `:` qualified(type($rawDesc)) `to` qualified(type($result))
696+
}];
697+
}
698+
699+
def TTNG_TensormapCreateOp: TTNG_Op<
700+
"tensormap_create",
701+
[
702+
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
703+
AttrSizedOperandSegments,
704+
]
705+
> {
706+
let summary = "Create a new TMA descriptor on device";
707+
let arguments = (
708+
ins
709+
TT_PtrType:$desc_ptr,
710+
TT_PtrType:$global_address,
711+
Variadic<I32>:$box_dim,
712+
Variadic<I32>:$global_dim,
713+
Variadic<I64>:$global_stride,
714+
Variadic<I32>:$element_stride,
715+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<15>]>:$elem_type,
716+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<2>]>:$interleave_layout,
717+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$swizzle_mode,
718+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$fill_mode
719+
);
720+
let extraClassDeclaration = [{
721+
int32_t getRank() {
722+
return getBoxDim().size();
723+
}
724+
}];
725+
let assemblyFormat = [{
726+
$desc_ptr `,` $global_address `,`
727+
`[` $box_dim `]` `,`
728+
`[` $global_dim `]` `,`
729+
`[` $global_stride `]` `,`
730+
`[` $element_stride `]`
731+
attr-dict `:` functional-type(operands, results)
732+
}];
733+
734+
let hasVerifier = 1;
735+
}
736+
737+
def TTNG_TensormapFenceproxyAcquireOp: TTNG_Op<
738+
"tensormap_fenceproxy_acquire",
739+
[MemoryEffects<[MemWrite<GlobalMemory>]>]
740+
> {
741+
let summary = "Acquire fence on a tensormap object";
742+
let arguments = (ins TT_PtrType:$desc_ptr);
743+
let assemblyFormat = [{
744+
$desc_ptr attr-dict `:` qualified(type($desc_ptr))
745+
}];
746+
}
747+
703748
#endif

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::M
6464
let summary = "lower to TMA load/store operations";
6565

6666
let description = [{
67-
Lower Triton experimental descriptor load to TMA load/store operations in TritonNvidiaGPUDialect.
67+
Lower Triton descriptor load to TMA load/store operations in TritonNvidiaGPUDialect.
6868
}];
6969

7070
let dependentDialects = [

lib/Analysis/Allocation.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "triton/Dialect/Triton/IR/Dialect.h"
1010
#include "triton/Dialect/Triton/IR/Utility.h"
1111
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
12+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1213
#include "llvm/ADT/SmallVector.h"
1314
#include "llvm/Support/Debug.h"
1415
#include "llvm/Support/raw_ostream.h"
@@ -17,6 +18,8 @@
1718
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
1819
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
1920

21+
namespace ttng = mlir::triton::nvidia_gpu;
22+
2023
namespace mlir {
2124

2225
//===----------------------------------------------------------------------===//
@@ -206,7 +209,7 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
206209
assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
207210
return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
208211
}
209-
if (isa<ExperimentalTensormapCreateOp>(op)) {
212+
if (isa<ttng::TensormapCreateOp>(op)) {
210213
constexpr int32_t kTMASize = 128;
211214
return kTMASize;
212215
}

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,8 +568,6 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
568568
GenericOpPattern<triton::DescriptorLoadOp>,
569569
GenericOpPattern<triton::DescriptorStoreOp>,
570570
GenericOpPattern<triton::DescriptorReduceOp>,
571-
GenericOpPattern<triton::ExperimentalTensormapCreateOp>,
572-
GenericOpPattern<triton::ExperimentalTensormapFenceproxyAcquireOp>,
573571
// this assumes the right layout will be set later for dot scaled.
574572
GenericOpPattern<triton::DotScaledOp>,
575573
GenericOpPattern<triton::CallOp>,

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,23 +1364,5 @@ LogicalResult DescriptorStoreOp::verify() {
13641364
getSrc().getType());
13651365
}
13661366

1367-
// -- ExperimentalTensormapCreateOp --
1368-
LogicalResult ExperimentalTensormapCreateOp::verify() {
1369-
auto rank = getBoxDim().size();
1370-
if (getGlobalDim().size() != rank) {
1371-
return emitError("Rank mismatch for global dim. Got ")
1372-
<< getGlobalDim().size() << " but expected " << rank;
1373-
}
1374-
if (getGlobalStride().size() + 1 != rank) {
1375-
return emitError("Rank mismatch for global stride. Got ")
1376-
<< getGlobalStride().size() << " but expected " << rank - 1;
1377-
}
1378-
if (getElementStride().size() != rank) {
1379-
return emitError("Rank mismatch for element stride. Got ")
1380-
<< getElementStride().size() << " but expected " << rank;
1381-
}
1382-
return success();
1383-
}
1384-
13851367
} // namespace triton
13861368
} // namespace mlir

lib/Dialect/TritonGPU/Transforms/Canonicalize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ void Canonicalize::runOnOperation() {
4545
BroadcastOp::getCanonicalizationPatterns(patterns, ctx);
4646
ExpandDimsOp::getCanonicalizationPatterns(patterns, ctx);
4747
ttg::WarpSpecializeOp::getCanonicalizationPatterns(patterns, ctx);
48-
ttng::TensorDescToTMAPtrOp::getCanonicalizationPatterns(patterns, ctx);
4948

5049
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
5150
}

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,7 @@ void createTMAAsyncCopy(
312312
Value view = createSingleBufferView(builder, alloc, insertIdx);
313313

314314
Value pred = builder.create<arith::ConstantIntOp>(1, 1);
315-
Value tmaPtr = builder.create<triton::nvidia_gpu::TensorDescToTMAPtrOp>(desc);
316-
createCopy(builder, tmaPtr, barrier, view, pred);
315+
createCopy(builder, desc, barrier, view, pred);
317316

318317
// Create local load after the wait
319318
builder.setInsertionPointAfter(waitOp);
@@ -697,8 +696,8 @@ LogicalResult rewriteTMABufferUpdates(
697696
if (failed(ttng::createTMADesc(nextBuf, makeDescOp, builder))) {
698697
return failure();
699698
}
700-
builder.create<triton::ExperimentalTensormapFenceproxyAcquireOp>(nextBuf);
701-
Value nextDesc = builder.create<triton::ReinterpretTensorDescOp>(
699+
builder.create<ttng::TensormapFenceproxyAcquireOp>(nextBuf);
700+
Value nextDesc = builder.create<ttng::ReinterpretTensorDescOp>(
702701
makeDescOp.getType(), nextBuf);
703702

704703
makeDescOp.getResult().replaceAllUsesWith(nextDesc);

0 commit comments

Comments
 (0)