@@ -26,12 +26,12 @@ Value createFullLike(OpBuilder &builder, Location loc, Value scalar,
   return builder.create<triton::SplatOp>(loc, tensorTy, scalar);
 }
 
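+// Compares an integer tensor element-wise against a splatted scalar using the
+// given predicate (defaults to equality).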
-Value createCmpIntTensorScalar(OpBuilder &builder, Location loc, Value tensor,
-                               Value scalar) {
+Value createCmpIntTensorScalar(
+    OpBuilder &builder, Location loc, Value tensor, Value scalar,
+    arith::CmpIPredicate predicate = arith::CmpIPredicate::eq) {
   auto tensorTy = cast<RankedTensorType>(tensor.getType());
   auto splat = createFullLike(builder, loc, scalar, tensorTy);
-  auto cmp = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                           tensor, splat);
+  auto cmp = builder.create<arith::CmpIOp>(loc, predicate, tensor, splat);
   return cmp;
 }
@@ -512,6 +512,186 @@ struct ClearReadBarrierOpConversion
   }
 };
 
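+// Lowers ExperimentalStageWriteForCommitOp: sets the write-commits entry of
+// the buffer being written to -1, staging it to be counted by the next
+// ExperimentalCommitWritesOp.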
+struct StageWriteForCommitOpConversion
+    : public ConvertOpToLLVMPattern<tti::ExperimentalStageWriteForCommitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult matchAndRewrite(tti::ExperimentalStageWriteForCommitOp op,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &b) const override {
+    Location loc = op.getLoc();
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(op);
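+    // If the op carries a predicate, guard the instrumentation with a
+    // conditional block so it only runs when the predicate is true.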
+    if (op.getPred()) {
+      auto [prevBlock, ifBlock, thenBlock] =
+          createIfBlock(b, loc, op.getPred());
+      b.setInsertionPointToStart(ifBlock);
+    }
+    TypedValue<RankedTensorType> buffers = op.getBuffers();
+    RankedTensorType writeCommitsType =
+        cast<RankedTensorType>(op.getWriteCommitsType());
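+    // The write-commits table lives in scratch memory: load it as a tensor,
+    // update it, and store it back below.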
+    Value writeCommits = tti::createLoadScratchMemory(
+                             b, loc, op.getWriteCommits(), writeCommitsType)
+                             ->getResult(0);
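+    // Convert the buffer memdesc to an i64 key so it can be compared against
+    // the entries of the `buffers` tensor.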
+    Value buf = createMemDescToI64(b, loc, getTypeConverter(),
+                                   op.getBuf().getType(), adaptor.getBuf());
+
+    // Gluon pseudo-code:
+    // write_commits = tl.where(bufs == buf, -1, write_commits)
+
+    auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
+    auto writeCommitsMinusOne =
+        tti::createConstIntTensor(b, loc, -1, writeCommitsType);
+    writeCommits = b.create<arith::SelectOp>(
+        loc, buffersEqBuf, writeCommitsMinusOne, writeCommits);
+    tti::createStoreScratchMemory(b, loc, op.getWriteCommits(), writeCommits,
+                                  writeCommitsType);
+    b.eraseOp(op);
+    return success();
+  }
+};
+
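+// Lowers ExperimentalCommitWritesOp: increments every positive (already
+// committed) entry and turns staged entries (-1) into 1, so each entry counts
+// how many commits have happened since the write was staged.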
+struct CommitWritesOpConversion
+    : public ConvertOpToLLVMPattern<tti::ExperimentalCommitWritesOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult matchAndRewrite(tti::ExperimentalCommitWritesOp op,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &b) const override {
+    Location loc = op.getLoc();
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(op);
+    if (op.getPred()) {
+      auto [prevBlock, ifBlock, thenBlock] =
+          createIfBlock(b, loc, op.getPred());
+      b.setInsertionPointToStart(ifBlock);
+    }
+    RankedTensorType writeCommitsType =
+        cast<RankedTensorType>(op.getWriteCommitsType());
+    Value writeCommits = tti::createLoadScratchMemory(
+                             b, loc, op.getWriteCommits(), writeCommitsType)
+                             ->getResult(0);
+
+    // Gluon pseudo-code:
+    // write_commits = tl.where(write_commits > 0, write_commits + 1,
+    //                          write_commits)
+    // write_commits = tl.where(write_commits == -1, 1, write_commits)
+
+    Type elementType = writeCommitsType.getElementType();
+    Value minusOne = b.create<arith::ConstantOp>(
+        loc, elementType, b.getIntegerAttr(elementType, -1));
+    Value zero = b.create<arith::ConstantOp>(loc, elementType,
+                                             b.getIntegerAttr(elementType, 0));
+    Value writeCommitsOne =
+        tti::createConstIntTensor(b, loc, 1, writeCommitsType);
+
+    auto writeCommitsGtZero = createCmpIntTensorScalar(
+        b, loc, writeCommits, zero, arith::CmpIPredicate::sgt);
+    auto writeCommitsPlusOne =
+        b.create<arith::AddIOp>(loc, writeCommits, writeCommitsOne);
+    writeCommits = b.create<arith::SelectOp>(loc, writeCommitsGtZero,
+                                             writeCommitsPlusOne, writeCommits);
+
+    auto writeCommitsEqMinusOne =
+        createCmpIntTensorScalar(b, loc, writeCommits, minusOne);
+    writeCommits = b.create<arith::SelectOp>(loc, writeCommitsEqMinusOne,
+                                             writeCommitsOne, writeCommits);
+    tti::createStoreScratchMemory(b, loc, op.getWriteCommits(), writeCommits,
+                                  writeCommitsType);
+    b.eraseOp(op);
+    return success();
+  }
+};
+
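+// Lowers ExperimentalClearWriteCommitsOp: zeroes every entry whose commit
+// count exceeds getOutstandingNum(), clearing writes that are no longer
+// outstanding.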
+struct ClearWriteCommitsOpConversion
+    : public ConvertOpToLLVMPattern<tti::ExperimentalClearWriteCommitsOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult matchAndRewrite(tti::ExperimentalClearWriteCommitsOp op,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &b) const override {
+    Location loc = op.getLoc();
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(op);
+    if (op.getPred()) {
+      auto [prevBlock, ifBlock, thenBlock] =
+          createIfBlock(b, loc, op.getPred());
+      b.setInsertionPointToStart(ifBlock);
+    }
+    RankedTensorType writeCommitsType =
+        cast<RankedTensorType>(op.getWriteCommitsType());
+    Value writeCommits = tti::createLoadScratchMemory(
+                             b, loc, op.getWriteCommits(), writeCommitsType)
+                             ->getResult(0);
+
+    // Gluon pseudo-code:
+    // write_commits = tl.where(write_commits > outstanding_num, 0,
+    //                          write_commits)
+
+    Type elementType = writeCommitsType.getElementType();
+    Value outstandingNum = b.create<arith::ConstantOp>(
+        loc, elementType,
+        b.getIntegerAttr(elementType, op.getOutstandingNum()));
+    Value writeCommitsZero =
+        tti::createConstIntTensor(b, loc, 0, writeCommitsType);
+    auto writeCommitsGtOutstandingNum = createCmpIntTensorScalar(
+        b, loc, writeCommits, outstandingNum, arith::CmpIPredicate::sgt);
+    writeCommits = b.create<arith::SelectOp>(loc, writeCommitsGtOutstandingNum,
+                                             writeCommitsZero, writeCommits);
+    tti::createStoreScratchMemory(b, loc, op.getWriteCommits(), writeCommits,
+                                  writeCommitsType);
+    b.eraseOp(op);
+    return success();
+  }
+};
+
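+// Lowers ExperimentalCheckWriteCommitOp: asserts, per thread, that the buffer
+// being accessed has no non-zero write-commit entry, i.e. no outstanding
+// writes.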
+struct CheckWriteCommitOpConversion
+    : public ConvertOpToLLVMPattern<tti::ExperimentalCheckWriteCommitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult matchAndRewrite(tti::ExperimentalCheckWriteCommitOp op,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &b) const override {
+    Location loc = op.getLoc();
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(op);
+    if (op.getPred()) {
+      auto [prevBlock, ifBlock, thenBlock] =
+          createIfBlock(b, loc, op.getPred());
+      b.setInsertionPointToStart(ifBlock);
+    }
+    TypedValue<RankedTensorType> buffers = op.getBuffers();
+    RankedTensorType writeCommitsType =
+        cast<RankedTensorType>(op.getWriteCommitsType());
+    Value writeCommits = tti::createLoadScratchMemory(
+                             b, loc, op.getWriteCommits(), writeCommitsType)
+                             ->getResult(0);
+    Value buf = createMemDescToI64(b, loc, getTypeConverter(),
+                                   op.getBuf().getType(), adaptor.getBuf());
+
+    // Gluon pseudo-code:
+    // curr_commits = tl.where(buf == buffers, write_commits, 0)
+    // tl.device_assert(curr_commits == 0,
+    //                  "Buffer being accessed has outstanding writes")
+
+    Type elementType = writeCommitsType.getElementType();
+    auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
+    auto zero = b.create<arith::ConstantOp>(loc, elementType,
+                                            b.getIntegerAttr(elementType, 0));
+    auto writeCommitsZero =
+        tti::createConstIntTensor(b, loc, 0, writeCommitsType);
+    auto currCommits = b.create<arith::SelectOp>(
+        loc, buffersEqBuf, writeCommits, writeCommitsZero);
+    auto currCommitsEqZero =
+        createCmpIntTensorScalar(b, loc, currCommits, zero);
+    b.create<tti::ExperimentalAssertInThreadOp>(
+        loc, currCommitsEqZero,
+        b.getStringAttr("Buffer being accessed has outstanding writes"), false);
+    b.eraseOp(op);
+    return success();
+  }
+};
+
 
 } // namespace
 
@@ -525,4 +705,8 @@ void mlir::triton::populateInstrumentationToLLVMPatterns(
   patterns.add<MarkAsReadOpConversion>(typeConverter);
   patterns.add<ClearWriteBarrierOpConversion>(typeConverter);
   patterns.add<ClearReadBarrierOpConversion>(typeConverter);
+  patterns.add<StageWriteForCommitOpConversion>(typeConverter);
+  patterns.add<CommitWritesOpConversion>(typeConverter);
+  patterns.add<ClearWriteCommitsOpConversion>(typeConverter);
+  patterns.add<CheckWriteCommitOpConversion>(typeConverter);
 }