@@ -26,12 +26,12 @@ Value createFullLike(OpBuilder &builder, Location loc, Value scalar,
2626 return builder.create <triton::SplatOp>(loc, tensorTy, scalar);
2727}
2828
29- Value createCmpIntTensorScalar (OpBuilder &builder, Location loc, Value tensor,
30- Value scalar) {
29+ Value createCmpIntTensorScalar (
30+ OpBuilder &builder, Location loc, Value tensor, Value scalar,
31+ arith::CmpIPredicate predicate = arith::CmpIPredicate::eq) {
3132 auto tensorTy = cast<RankedTensorType>(tensor.getType ());
3233 auto splat = createFullLike (builder, loc, scalar, tensorTy);
33- auto cmp = builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
34- tensor, splat);
34+ auto cmp = builder.create <arith::CmpIOp>(loc, predicate, tensor, splat);
3535 return cmp;
3636}
3737
@@ -512,6 +512,186 @@ struct ClearReadBarrierOpConversion
512512 }
513513};
514514
515+ struct StageWriteForCommitOpConversion
516+ : public ConvertOpToLLVMPattern<tti::ExperimentalStageWriteForCommitOp> {
517+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
518+
519+ LogicalResult matchAndRewrite (tti::ExperimentalStageWriteForCommitOp op,
520+ OpAdaptor adaptor,
521+ ConversionPatternRewriter &b) const override {
522+ Location loc = op.getLoc ();
523+ OpBuilder::InsertionGuard guard (b);
524+ b.setInsertionPoint (op);
525+ if (op.getPred ()) {
526+ auto [prevBlock, ifBlock, thenBlock] =
527+ createIfBlock (b, loc, op.getPred ());
528+ b.setInsertionPointToStart (ifBlock);
529+ }
530+ TypedValue<RankedTensorType> buffers = op.getBuffers ();
531+ RankedTensorType writeCommitsType =
532+ cast<RankedTensorType>(op.getWriteCommitsType ());
533+ Value writeCommits = tti::createLoadScratchMemory (
534+ b, loc, op.getWriteCommits (), writeCommitsType)
535+ ->getResult (0 );
536+ Value buf = createMemDescToI64 (b, loc, getTypeConverter (),
537+ op.getBuf ().getType (), adaptor.getBuf ());
538+
539+ // Gluon pseudo-code:
540+ // write_commits = tl.where(bufs == buf, -1, write_commits)
541+
542+ auto buffersEqBuf = createCmpIntTensorScalar (b, loc, buffers, buf);
543+ auto writeCommitsMinusOne =
544+ tti::createConstIntTensor (b, loc, -1 , writeCommitsType);
545+ writeCommits = b.create <arith::SelectOp>(
546+ loc, buffersEqBuf, writeCommitsMinusOne, writeCommits);
547+ tti::createStoreScratchMemory (b, loc, op.getWriteCommits (), writeCommits,
548+ writeCommitsType);
549+ b.eraseOp (op);
550+ return success ();
551+ }
552+ };
553+
554+ struct CommitWritesOpConversion
555+ : public ConvertOpToLLVMPattern<tti::ExperimentalCommitWritesOp> {
556+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
557+
558+ LogicalResult matchAndRewrite (tti::ExperimentalCommitWritesOp op,
559+ OpAdaptor adaptor,
560+ ConversionPatternRewriter &b) const override {
561+ Location loc = op.getLoc ();
562+ OpBuilder::InsertionGuard guard (b);
563+ b.setInsertionPoint (op);
564+ if (op.getPred ()) {
565+ auto [prevBlock, ifBlock, thenBlock] =
566+ createIfBlock (b, loc, op.getPred ());
567+ b.setInsertionPointToStart (ifBlock);
568+ }
569+ RankedTensorType writeCommitsType =
570+ cast<RankedTensorType>(op.getWriteCommitsType ());
571+ Value writeCommits = tti::createLoadScratchMemory (
572+ b, loc, op.getWriteCommits (), writeCommitsType)
573+ ->getResult (0 );
574+
575+ // Gluon pseudo-code:
576+ // write_commits = tl.where(write_commits > 0, write_commits + 1,
577+ // write_commits) write_commits = tl.where(write_commits == -1, 1,
578+ // write_commits)
579+
580+ Type elementType = writeCommitsType.getElementType ();
581+ Value minusOne = b.create <arith::ConstantOp>(
582+ loc, elementType, b.getIntegerAttr (elementType, -1 ));
583+ Value zero = b.create <arith::ConstantOp>(loc, elementType,
584+ b.getIntegerAttr (elementType, 0 ));
585+ Value writeCommitsOne =
586+ tti::createConstIntTensor (b, loc, 1 , writeCommitsType);
587+
588+ auto writeCommitsGtZero = createCmpIntTensorScalar (
589+ b, loc, writeCommits, zero, arith::CmpIPredicate::sgt);
590+ auto writeCommitsPlusOne =
591+ b.create <arith::AddIOp>(loc, writeCommits, writeCommitsOne);
592+ writeCommits = b.create <arith::SelectOp>(loc, writeCommitsGtZero,
593+ writeCommitsPlusOne, writeCommits);
594+
595+ auto writeCommitsEqMinusOne =
596+ createCmpIntTensorScalar (b, loc, writeCommits, minusOne);
597+ writeCommits = b.create <arith::SelectOp>(loc, writeCommitsEqMinusOne,
598+ writeCommitsOne, writeCommits);
599+ tti::createStoreScratchMemory (b, loc, op.getWriteCommits (), writeCommits,
600+ writeCommitsType);
601+ b.eraseOp (op);
602+ return success ();
603+ }
604+ };
605+
606+ struct ClearWriteCommitsOpConversion
607+ : public ConvertOpToLLVMPattern<tti::ExperimentalClearWriteCommitsOp> {
608+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
609+
610+ LogicalResult matchAndRewrite (tti::ExperimentalClearWriteCommitsOp op,
611+ OpAdaptor adaptor,
612+ ConversionPatternRewriter &b) const override {
613+ Location loc = op.getLoc ();
614+ OpBuilder::InsertionGuard guard (b);
615+ b.setInsertionPoint (op);
616+ if (op.getPred ()) {
617+ auto [prevBlock, ifBlock, thenBlock] =
618+ createIfBlock (b, loc, op.getPred ());
619+ b.setInsertionPointToStart (ifBlock);
620+ }
621+ RankedTensorType writeCommitsType =
622+ cast<RankedTensorType>(op.getWriteCommitsType ());
623+ Value writeCommits = tti::createLoadScratchMemory (
624+ b, loc, op.getWriteCommits (), writeCommitsType)
625+ ->getResult (0 );
626+
627+ // Gluon pseudo-code:
628+ // write_commits = tl.where(write_commits > outstanding_num, 0,
629+ // write_commits)
630+
631+ Type elementType = writeCommitsType.getElementType ();
632+ Value outstandingNum = b.create <arith::ConstantOp>(
633+ loc, elementType,
634+ b.getIntegerAttr (elementType, op.getOutstandingNum ()));
635+ Value writeCommitsZero =
636+ tti::createConstIntTensor (b, loc, 0 , writeCommitsType);
637+ auto writeCommitsGtOutstandingNum = createCmpIntTensorScalar (
638+ b, loc, writeCommits, outstandingNum, arith::CmpIPredicate::sgt);
639+ writeCommits = b.create <arith::SelectOp>(loc, writeCommitsGtOutstandingNum,
640+ writeCommitsZero, writeCommits);
641+ tti::createStoreScratchMemory (b, loc, op.getWriteCommits (), writeCommits,
642+ writeCommitsType);
643+ b.eraseOp (op);
644+ return success ();
645+ }
646+ };
647+
648+ struct CheckWriteCommitOpConversion
649+ : public ConvertOpToLLVMPattern<tti::ExperimentalCheckWriteCommitOp> {
650+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
651+
652+ LogicalResult matchAndRewrite (tti::ExperimentalCheckWriteCommitOp op,
653+ OpAdaptor adaptor,
654+ ConversionPatternRewriter &b) const override {
655+ Location loc = op.getLoc ();
656+ OpBuilder::InsertionGuard guard (b);
657+ b.setInsertionPoint (op);
658+ if (op.getPred ()) {
659+ auto [prevBlock, ifBlock, thenBlock] =
660+ createIfBlock (b, loc, op.getPred ());
661+ b.setInsertionPointToStart (ifBlock);
662+ }
663+ TypedValue<RankedTensorType> buffers = op.getBuffers ();
664+ RankedTensorType writeCommitsType =
665+ cast<RankedTensorType>(op.getWriteCommitsType ());
666+ Value writeCommits = tti::createLoadScratchMemory (
667+ b, loc, op.getWriteCommits (), writeCommitsType)
668+ ->getResult (0 );
669+ Value buf = createMemDescToI64 (b, loc, getTypeConverter (),
670+ op.getBuf ().getType (), adaptor.getBuf ());
671+
672+ // Gluon pseudo-code:
673+ // curr_commits = tl.where(buf == buffers, write_commits, 0)
674+ // tl.device_assert(curr_commits == 0, "Buffer being accessed has
675+ // outstanding writes")
676+
677+ Type elementType = writeCommitsType.getElementType ();
678+ auto buffersEqBuf = createCmpIntTensorScalar (b, loc, buffers, buf);
679+ auto zero = b.create <arith::ConstantOp>(loc, elementType,
680+ b.getIntegerAttr (elementType, 0 ));
681+ auto writeCommitsZero =
682+ tti::createConstIntTensor (b, loc, 0 , writeCommitsType);
683+ auto currCommits = b.create <arith::SelectOp>(
684+ loc, buffersEqBuf, writeCommits, writeCommitsZero);
685+ auto currCommitsEqZero =
686+ createCmpIntTensorScalar (b, loc, currCommits, zero);
687+ b.create <tti::ExperimentalAssertInThreadOp>(
688+ loc, currCommitsEqZero,
689+ b.getStringAttr (" Buffer being accessed has outstanding writes" ), false );
690+ b.eraseOp (op);
691+ return success ();
692+ }
693+ };
694+
515695} // namespace
516696
517697void mlir::triton::populateInstrumentationToLLVMPatterns (
@@ -525,4 +705,8 @@ void mlir::triton::populateInstrumentationToLLVMPatterns(
525705 patterns.add <MarkAsReadOpConversion>(typeConverter);
526706 patterns.add <ClearWriteBarrierOpConversion>(typeConverter);
527707 patterns.add <ClearReadBarrierOpConversion>(typeConverter);
708+ patterns.add <StageWriteForCommitOpConversion>(typeConverter);
709+ patterns.add <CommitWritesOpConversion>(typeConverter);
710+ patterns.add <ClearWriteCommitsOpConversion>(typeConverter);
711+ patterns.add <CheckWriteCommitOpConversion>(typeConverter);
528712}
0 commit comments