Skip to content

Commit e8c4711

Browse files
authored
[NVIDIA] Add is_async flag to MMAv5 ops (#7590)
Based on the discussion in triton-lang/triton#7581 (comment), ~the MMAv5 ops are now async by default at the IR def level, more faithfully modeling the corresponding ptx instructions~. The new flag determines the sync or async nature of the ops, rather than the presence of the "completion barrier".
1 parent ab4a29a commit e8c4711

File tree

19 files changed

+99
-60
lines changed

19 files changed

+99
-60
lines changed

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def MMAv5OpInterface : OpInterface<"MMAv5OpInterface"> {
4949
InterfaceMethod<"Get the produced write dependency of the accumulator.",
5050
"::mlir::Value",
5151
"getToken">,
52+
InterfaceMethod<"Indicate that this MMA op executes asynchronously.",
53+
"void",
54+
"setIsAsync",
55+
(ins "bool":$isAsync)>,
5256
];
5357
}
5458
#endif // TRITON_NVIDIAGPU_OP_INTERFACES

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -421,8 +421,8 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
421421

422422
let description = [{
423423
$d += matrix_multiply($a, $b).
424-
If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
425-
If there is a barrier the result will be safe to read after a barrier wait.
424+
If is_async is false, the op executes synchronously. The barrier operands must not be present in that case.
425+
Otherwise, if a barrier is given, the op will trigger a commit/arrive on it. The result will be safe to read after a barrier wait.
426426
If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
427427
and synchronize both CTAs if the op is synchronous.
428428

@@ -440,7 +440,8 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
440440
I1:$pred,
441441
Variadic<TTG_MemDescType>:$barriers,
442442
Variadic<I1>:$barrier_preds,
443-
OptionalAttr<UnitAttr>:$two_ctas
443+
UnitAttr:$is_async,
444+
UnitAttr:$two_ctas
444445
);
445446
let results = (outs Optional<TTG_AsyncToken>:$token);
446447

@@ -449,7 +450,8 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
449450
"Value":$a, "Value":$b, "Value":$d, "Value":$acc_dep, "Value":$useD,
450451
"Value":$pred, CArg<"bool", "false">:$two_ctas,
451452
CArg<"ValueRange", "{}">:$barriers,
452-
CArg<"ValueRange", "{}">:$barrier_preds)>
453+
CArg<"ValueRange", "{}">:$barrier_preds,
454+
CArg<"bool", "false">:$is_async)>
453455
];
454456

455457
let assemblyFormat = [{
@@ -458,6 +460,8 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
458460
attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
459461
qualified(type($d)) (`,` qualified(type($barriers))^)?
460462
}];
463+
464+
let hasVerifier = 1;
461465
}
462466

463467
def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
@@ -470,8 +474,9 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
470474

471475
let description = [{
472476
$d += matrix_multiply(scale($lhs, $lhs_scale), scale($rhs, $rhs_scale))
473-
If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
474-
If there is a barrier the result will be safe to read after a barrier wait.
477+
If is_async is false, the op executes synchronously. The barrier operands must not be present in that case.
478+
Otherwise, if a barrier is given, the op will trigger a commit/arrive on it.
479+
The result will be safe to read after a barrier wait.
475480

476481
This operation takes and produces an optional token to indicate TMEM read
477482
and write on its accumulator operand. When the tokens are present, they can
@@ -490,7 +495,8 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
490495
I1:$useD,
491496
I1:$pred,
492497
Variadic<TTG_MemDescType>:$barriers,
493-
Variadic<I1>:$barrier_preds
498+
Variadic<I1>:$barrier_preds,
499+
UnitAttr:$is_async
494500
);
495501
let results = (outs Optional<TTG_AsyncToken>:$token);
496502

@@ -510,7 +516,8 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
510516
"::mlir::triton::ScaleDotElemType":$b_type,
511517
"::mlir::Value":$useD, "::mlir::Value":$pred,
512518
CArg<"::mlir::ValueRange", "{}">:$barriers,
513-
CArg<"::mlir::ValueRange", "{}">:$barrier_preds)>
519+
CArg<"::mlir::ValueRange", "{}">:$barrier_preds,
520+
CArg<"bool", "false">:$is_async)>
514521
];
515522

516523
let assemblyFormat = [{
@@ -521,6 +528,8 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
521528
qualified(type($d)) `,` qualified(type($a_scale)) `,`
522529
qualified(type($b_scale)) (`,` qualified(type($barriers))^)?
523530
}];
531+
532+
let hasVerifier = 1;
524533
}
525534

526535
def TTNG_TCGen5CommitOp : TTNG_Op<"tc_gen5_commit"> {

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
814814
triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
815815
}
816816
mma.addCompletionBarrier(barrierSlice, vTrue);
817+
mma.setIsAsync(true);
817818

818819
// List of buffers that may be used until wait completes
819820
SmallVector<Value> waitBuffers;

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ LogicalResult PipelinedLoadGroup::lowerLoads(WarpSchedule &schedule,
422422
for (Operation *asyncUser : distinctAsyncUsers) {
423423
if (auto mmaOp = dyn_cast<ttng::MMAv5OpInterface>(asyncUser)) {
424424
mmaOp.addCompletionBarrier(curEmptyBar, b.boolCst(true));
425+
mmaOp.setIsAsync(true);
425426
continue;
426427
}
427428
llvm::report_fatal_error("FIXME: unhandled async user of pipelined load: " +
@@ -764,6 +765,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
764765
b.setInsertionPoint(mmaOp);
765766
Value bar = createSingleBufferView(b, node.barNext, node.index);
766767
mmaOp.addCompletionBarrier(bar, userPred);
768+
mmaOp.setIsAsync(true);
767769
} else {
768770
b.setInsertionPointAfter(lastOp);
769771
if (isa<scf::IfOp>(lastOp->getParentOp()) && accIsMultiBuffered)
@@ -802,6 +804,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
802804
b.createInto<ttng::WaitBarrierOp>(*schedule.getPartition(mmaOp),
803805
getStageCluster(mmaOp), readyBar, phase);
804806
mmaOp.addCompletionBarrier(emptyBar, b.boolCst(true));
807+
mmaOp.setIsAsync(true);
805808
}
806809

807810
if (nodes.back().barNext) {

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,13 @@ static void printToken(OpAsmPrinter &p, Operation *op, Value dep, Type token) {
243243
p << ']';
244244
}
245245

246+
LogicalResult TCGen5MMAOp::verify() {
247+
if (!getIsAsync() && !getBarriers().empty()) {
248+
return emitOpError("The op is synchronous but a barrier is present.");
249+
}
250+
return success();
251+
}
252+
246253
void TCGen5MMAOp::getEffects(
247254
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
248255
&effects) {
@@ -296,12 +303,23 @@ void TCGen5MMAOp::setPredicate(Value pred) { getPredMutable().assign(pred); }
296303
void TCGen5MMAOp::build(OpBuilder &builder, OperationState &state, Type token,
297304
Value a, Value b, Value d, Value accDep, Value useD,
298305
Value pred, bool useTwoCTAs, ValueRange barriers,
299-
ValueRange barrierPreds) {
306+
ValueRange barrierPreds, bool isAsync) {
307+
if (!barriers.empty()) {
308+
isAsync = true;
309+
}
300310
build(builder, state, token, a, b, d, accDep, useD, pred, barriers,
301-
barrierPreds, useTwoCTAs ? builder.getUnitAttr() : UnitAttr());
311+
barrierPreds, isAsync ? builder.getUnitAttr() : UnitAttr(),
312+
useTwoCTAs ? builder.getUnitAttr() : UnitAttr());
302313
}
303314

304315
// -- TCGen5MMAScaledOp --
316+
LogicalResult TCGen5MMAScaledOp::verify() {
317+
if (!getIsAsync() && !getBarriers().empty()) {
318+
return emitOpError("The op is synchronous but a barrier is present.");
319+
}
320+
return success();
321+
}
322+
305323
void TCGen5MMAScaledOp::getEffects(
306324
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
307325
&effects) {
@@ -450,12 +468,15 @@ void TCGen5MMAScaledOp::build(OpBuilder &builder, OperationState &state,
450468
Value accDep, Value aScale, Value bScale,
451469
ScaleDotElemType aType, ScaleDotElemType bType,
452470
Value useD, Value pred, ValueRange barriers,
453-
ValueRange barrierPreds) {
471+
ValueRange barrierPreds, bool isAsync) {
454472
MLIRContext *ctx = builder.getContext();
473+
if (!barriers.empty()) {
474+
isAsync = true;
475+
}
455476
build(builder, state, token, a, b, d, accDep, aScale, bScale,
456477
ScaleDotElemTypeAttr::get(ctx, aType),
457478
ScaleDotElemTypeAttr::get(ctx, bType), useD, pred, barriers,
458-
barrierPreds);
479+
barrierPreds, isAsync ? builder.getUnitAttr() : UnitAttr());
459480
}
460481

461482
// -- TMEMStoreOp --

lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class SyncMMALowering : public OpRewritePattern<TCGen5MMAOpTy> {
2424
LogicalResult matchAndRewrite(TCGen5MMAOpTy op,
2525
PatternRewriter &rewriter) const override {
2626
// If the op doesn't have synchronous semantic skip the pattern.
27-
if (!op.getBarriers().empty())
27+
if (op.getIsAsync())
2828
return failure();
2929
MLIRContext *ctx = op.getContext();
3030
Location loc = op.getLoc();
@@ -42,6 +42,7 @@ class SyncMMALowering : public OpRewritePattern<TCGen5MMAOpTy> {
4242
rewriter.create<InitBarrierOp>(loc, barrierAlloc, 1);
4343
op.addCompletionBarrier(barrierAlloc,
4444
rewriter.create<arith::ConstantIntOp>(loc, 1, 1));
45+
op.setIsAsync(true);
4546

4647
rewriter.setInsertionPointAfter(op);
4748
Value phase = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);

python/test/gluon/test_frontend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def test_tcgen05_mma_mbar(fresh_knobs):
567567
%true = arith.constant true loc(#loc)
568568
%true_0 = arith.constant true loc(#loc)
569569
%true_1 = arith.constant true loc(#loc)
570-
%3 = ttng.tc_gen5_mma %0, %1, %result[], %true, %true_0, %2[%true_1] : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
570+
%3 = ttng.tc_gen5_mma %0, %1, %result[], %true, %true_0, %2[%true_1] {is_async} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
571571
tt.return loc(#loc)
572572
} loc(#loc)
573573
} loc(#loc)

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
2323
%pred: i1,
2424
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
2525
%barrierPred: i1) {
26-
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] :
26+
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async} :
2727
!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
2828
!ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>,
2929
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -56,7 +56,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
5656
%pred: i1,
5757
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
5858
%barrierPred: i1) {
59-
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] :
59+
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async} :
6060
!ttg.memdesc<128x16xf16, #shared, #ttg.shared_memory>,
6161
!ttg.memdesc<16x128xf16, #shared1, #ttg.shared_memory>,
6262
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -89,7 +89,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
8989
%pred: i1,
9090
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
9191
%barrierPred: i1) {
92-
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] :
92+
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async} :
9393
!ttg.memdesc<128x16xf16, #shared, #ttg.shared_memory>,
9494
!ttg.memdesc<16x128xf16, #shared1, #ttg.shared_memory>,
9595
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -219,7 +219,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
219219
%pred: i1,
220220
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>,
221221
%barrierPred: i1) {
222-
ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e4m3 rhs = e2m1, %barrier[%barrierPred] :
222+
ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e4m3 rhs = e2m1, %barrier[%barrierPred] {is_async} :
223223
!ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>,
224224
!ttg.memdesc<32x128xi8, #shared1, #ttg.shared_memory>,
225225
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -256,7 +256,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
256256
%pred: i1,
257257
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>,
258258
%barrierPred: i1) {
259-
ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e4m3, %barrier[%barrierPred] :
259+
ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e4m3, %barrier[%barrierPred] {is_async} :
260260
!ttg.memdesc<128x64xi8, #shared1, #ttg.shared_memory>,
261261
!ttg.memdesc<128x128xi8, #shared, #ttg.shared_memory>,
262262
!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -285,7 +285,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
285285
// CHECK: tcgen05.mma.cta_group::2.kind::f16
286286
// CHECK: tcgen05.mma.cta_group::2.kind::f16
287287
// CHECK: tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64
288-
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {two_ctas} :
288+
ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async, two_ctas} :
289289
!ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory>,
290290
!ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory>,
291291
!ttg.memdesc<256x128xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -334,7 +334,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
334334
%pred: i1,
335335
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
336336
%barrierPred: i1) {
337-
ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e2m1, %barrier[%barrierPred] :
337+
ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e2m1, %barrier[%barrierPred] {is_async} :
338338
!ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>,
339339
!ttg.memdesc<64x256xi8, #shared1, #ttg.shared_memory>,
340340
!ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -368,7 +368,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
368368
%pred: i1,
369369
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
370370
%barrierPred: i1) {
371-
ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e2m1, %barrier[%barrierPred] :
371+
ttng.tc_gen5_mma_scaled %a, %b, %c, %scale_a, %scale_b, %useAcc, %pred lhs = e2m1 rhs = e2m1, %barrier[%barrierPred] {is_async} :
372372
!ttg.memdesc<128x64xi8, #shared, #ttg.shared_memory>,
373373
!ttg.memdesc<64x256xi8, #shared1, #ttg.shared_memory>,
374374
!ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>,
@@ -584,7 +584,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
584584
tt.func @tc_gen5_mma_lhs_tmem(%arg0: !ttg.memdesc<128x32xf16, #tmem, #ttng.tensor_memory>, %arg1: !ttg.memdesc<32x128xf16, #shared, #smem>, %arg2: !ttg.memdesc<128x128xf32, #tmem1, #ttng.tensor_memory, mutable>, %arg3: i1, %arg4: i1, %arg5: !ttg.memdesc<1xi64, #shared1, #smem>, %barrierPred: i1) {
585585
// CHECK-LABEL: tc_gen5_mma_lhs_tmem
586586
// CHECK: tcgen05.mma.cta_group::1.kind::f16
587-
ttng.tc_gen5_mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5[%barrierPred] :
587+
ttng.tc_gen5_mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5[%barrierPred] {is_async} :
588588
!ttg.memdesc<128x32xf16, #tmem, #ttng.tensor_memory>,
589589
!ttg.memdesc<32x128xf16, #shared, #smem>,
590590
!ttg.memdesc<128x128xf32, #tmem1, #ttng.tensor_memory, mutable>,

test/NVWS/lower_warp_group.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
2222
%false = arith.constant false
2323
nvws.warp_group
2424
partition0 num_warps(8) {
25-
ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false]:
25+
ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false] {is_async} :
2626
!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
2727
!ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
2828
!ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>,
@@ -55,7 +55,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
5555
%false = arith.constant false
5656
nvws.warp_group
5757
partition0 num_warps(4) {
58-
ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false]:
58+
ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false] {is_async} :
5959
!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
6060
!ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
6161
!ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>,
@@ -99,7 +99,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
9999
%c0 = arith.constant 0 : i32
100100
nvws.warp_group
101101
partition0 num_warps(4) {
102-
ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false]:
102+
ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false] {is_async} :
103103
!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,
104104
!ttg.memdesc<128x256xf16, #shared1, #ttg.shared_memory>,
105105
!ttg.memdesc<128x256xf16, #acc_tmem, #ttng.tensor_memory, mutable>,

test/TritonGPU/consan.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shar
197197
ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
198198
%result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>
199199
%true = arith.constant true
200-
ttng.tc_gen5_mma %0, %1, %result[], %true, %true, %bar[%true] : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable>
200+
ttng.tc_gen5_mma %0, %1, %result[], %true, %true, %bar[%true] {is_async} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable>
201201
tt.return
202202
}
203203
}

0 commit comments

Comments
 (0)