Skip to content

Commit be3fdef

Browse files
Merge OpenAI Triton commit 87f5aa4 (#4760)
This PR change the Triton base from 0560390 to 87f5aa4 (Jul 14). Pass rate: 98.46%
2 parents b66fbd6 + 34663ad commit be3fdef

File tree

9 files changed

+673
-106
lines changed

9 files changed

+673
-106
lines changed

include/triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,71 @@ def TTI_ExperimentalClearReadBarrierOp : TTI_Op<"experimental_clear_read_barrier
160160
let hasVerifier = 1;
161161
}
162162

163+
def TTI_ExperimentalStageWriteForCommitOp : TTI_Op<"experimental_stage_write_for_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
164+
let summary = "Preapre to an async copy of a buffer. Staged until commit_group is called.";
165+
let description = [{
166+
Preapre to an async copy of a buffer. Staged until commit_group is called. The implementation will write `-1` to the
167+
`write_commits` tensor under the indices corresponding to the buffer.
168+
}];
169+
let arguments = (ins
170+
TTG_MemDescType:$buf,
171+
TT_Tensor:$buffers,
172+
TT_PtrLike:$writeCommits,
173+
TypeAttr:$writeCommitsType,
174+
Optional<I1>:$pred
175+
);
176+
let assemblyFormat = [{
177+
$buf `{` $buffers `,` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeCommits)
178+
}];
179+
// let hasVerifier = 1;
180+
}
181+
182+
def TTI_ExperimentalCommitWritesOp : TTI_Op<"experimental_commit_writes", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
183+
let summary = "Commit all the staged writes for all the buffers.";
184+
let description = [{
185+
Commit all the staged writes for all the buffers.
186+
}];
187+
let arguments = (ins
188+
TT_PtrLike:$writeCommits,
189+
TypeAttr:$writeCommitsType,
190+
Optional<I1>:$pred);
191+
let assemblyFormat = [{
192+
`{` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($writeCommits)
193+
}];
194+
// let hasVerifier = 1;
195+
}
196+
197+
def TTI_ExperimentalClearWriteCommitsOp : TTI_Op<"experimental_clear_write_commits", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
198+
let summary = "Clear all the write commits more distant than `outstandingNum.";
199+
let description = [{
200+
Clear all the write commits more distant than `outstandingNum` from the current thread.
201+
}];
202+
let arguments = (ins
203+
TT_PtrLike:$writeCommits,
204+
TypeAttr:$writeCommitsType,
205+
I32Attr:$outstandingNum,
206+
Optional<I1>:$pred);
207+
let assemblyFormat = [{
208+
`{` $writeCommits `(` $writeCommitsType `)` `}` `,` $outstandingNum (`,` $pred^)? attr-dict `:` type($writeCommits)
209+
}];
210+
// let hasVerifier = 1;
211+
}
212+
213+
def TTI_ExperimentalCheckWriteCommitOp : TTI_Op<"experimental_check_write_commit", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
214+
let summary = "Check if the buffer has an outstanding write commit.";
215+
let description = [{
216+
Check if the buffer has an outstanding write commit.
217+
}];
218+
let arguments = (ins
219+
TTG_MemDescType:$buf,
220+
TT_Tensor:$buffers,
221+
TT_PtrLike:$writeCommits,
222+
TypeAttr:$writeCommitsType,
223+
Optional<I1>:$pred);
224+
let assemblyFormat = [{
225+
$buf `{` $buffers `,` $writeCommits `(` $writeCommitsType `)` `}` (`,` $pred^)? attr-dict `:` type($buf) `,` type($buffers) `,` type($writeCommits)
226+
}];
227+
// let hasVerifier = 1;
228+
}
229+
163230
#endif // TRITONINSTRUMENT_OPS

lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp

Lines changed: 188 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ Value createFullLike(OpBuilder &builder, Location loc, Value scalar,
2626
return builder.create<triton::SplatOp>(loc, tensorTy, scalar);
2727
}
2828

29-
Value createCmpIntTensorScalar(OpBuilder &builder, Location loc, Value tensor,
30-
Value scalar) {
29+
Value createCmpIntTensorScalar(
30+
OpBuilder &builder, Location loc, Value tensor, Value scalar,
31+
arith::CmpIPredicate predicate = arith::CmpIPredicate::eq) {
3132
auto tensorTy = cast<RankedTensorType>(tensor.getType());
3233
auto splat = createFullLike(builder, loc, scalar, tensorTy);
33-
auto cmp = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
34-
tensor, splat);
34+
auto cmp = builder.create<arith::CmpIOp>(loc, predicate, tensor, splat);
3535
return cmp;
3636
}
3737

@@ -512,6 +512,186 @@ struct ClearReadBarrierOpConversion
512512
}
513513
};
514514

515+
struct StageWriteForCommitOpConversion
516+
: public ConvertOpToLLVMPattern<tti::ExperimentalStageWriteForCommitOp> {
517+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
518+
519+
LogicalResult matchAndRewrite(tti::ExperimentalStageWriteForCommitOp op,
520+
OpAdaptor adaptor,
521+
ConversionPatternRewriter &b) const override {
522+
Location loc = op.getLoc();
523+
OpBuilder::InsertionGuard guard(b);
524+
b.setInsertionPoint(op);
525+
if (op.getPred()) {
526+
auto [prevBlock, ifBlock, thenBlock] =
527+
createIfBlock(b, loc, op.getPred());
528+
b.setInsertionPointToStart(ifBlock);
529+
}
530+
TypedValue<RankedTensorType> buffers = op.getBuffers();
531+
RankedTensorType writeCommitsType =
532+
cast<RankedTensorType>(op.getWriteCommitsType());
533+
Value writeCommits = tti::createLoadScratchMemory(
534+
b, loc, op.getWriteCommits(), writeCommitsType)
535+
->getResult(0);
536+
Value buf = createMemDescToI64(b, loc, getTypeConverter(),
537+
op.getBuf().getType(), adaptor.getBuf());
538+
539+
// Gluon pseudo-code:
540+
// write_commits = tl.where(bufs == buf, -1, write_commits)
541+
542+
auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
543+
auto writeCommitsMinusOne =
544+
tti::createConstIntTensor(b, loc, -1, writeCommitsType);
545+
writeCommits = b.create<arith::SelectOp>(
546+
loc, buffersEqBuf, writeCommitsMinusOne, writeCommits);
547+
tti::createStoreScratchMemory(b, loc, op.getWriteCommits(), writeCommits,
548+
writeCommitsType);
549+
b.eraseOp(op);
550+
return success();
551+
}
552+
};
553+
554+
struct CommitWritesOpConversion
555+
: public ConvertOpToLLVMPattern<tti::ExperimentalCommitWritesOp> {
556+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
557+
558+
LogicalResult matchAndRewrite(tti::ExperimentalCommitWritesOp op,
559+
OpAdaptor adaptor,
560+
ConversionPatternRewriter &b) const override {
561+
Location loc = op.getLoc();
562+
OpBuilder::InsertionGuard guard(b);
563+
b.setInsertionPoint(op);
564+
if (op.getPred()) {
565+
auto [prevBlock, ifBlock, thenBlock] =
566+
createIfBlock(b, loc, op.getPred());
567+
b.setInsertionPointToStart(ifBlock);
568+
}
569+
RankedTensorType writeCommitsType =
570+
cast<RankedTensorType>(op.getWriteCommitsType());
571+
Value writeCommits = tti::createLoadScratchMemory(
572+
b, loc, op.getWriteCommits(), writeCommitsType)
573+
->getResult(0);
574+
575+
// Gluon pseudo-code:
576+
// write_commits = tl.where(write_commits > 0, write_commits + 1,
577+
// write_commits) write_commits = tl.where(write_commits == -1, 1,
578+
// write_commits)
579+
580+
Type elementType = writeCommitsType.getElementType();
581+
Value minusOne = b.create<arith::ConstantOp>(
582+
loc, elementType, b.getIntegerAttr(elementType, -1));
583+
Value zero = b.create<arith::ConstantOp>(loc, elementType,
584+
b.getIntegerAttr(elementType, 0));
585+
Value writeCommitsOne =
586+
tti::createConstIntTensor(b, loc, 1, writeCommitsType);
587+
588+
auto writeCommitsGtZero = createCmpIntTensorScalar(
589+
b, loc, writeCommits, zero, arith::CmpIPredicate::sgt);
590+
auto writeCommitsPlusOne =
591+
b.create<arith::AddIOp>(loc, writeCommits, writeCommitsOne);
592+
writeCommits = b.create<arith::SelectOp>(loc, writeCommitsGtZero,
593+
writeCommitsPlusOne, writeCommits);
594+
595+
auto writeCommitsEqMinusOne =
596+
createCmpIntTensorScalar(b, loc, writeCommits, minusOne);
597+
writeCommits = b.create<arith::SelectOp>(loc, writeCommitsEqMinusOne,
598+
writeCommitsOne, writeCommits);
599+
tti::createStoreScratchMemory(b, loc, op.getWriteCommits(), writeCommits,
600+
writeCommitsType);
601+
b.eraseOp(op);
602+
return success();
603+
}
604+
};
605+
606+
struct ClearWriteCommitsOpConversion
607+
: public ConvertOpToLLVMPattern<tti::ExperimentalClearWriteCommitsOp> {
608+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
609+
610+
LogicalResult matchAndRewrite(tti::ExperimentalClearWriteCommitsOp op,
611+
OpAdaptor adaptor,
612+
ConversionPatternRewriter &b) const override {
613+
Location loc = op.getLoc();
614+
OpBuilder::InsertionGuard guard(b);
615+
b.setInsertionPoint(op);
616+
if (op.getPred()) {
617+
auto [prevBlock, ifBlock, thenBlock] =
618+
createIfBlock(b, loc, op.getPred());
619+
b.setInsertionPointToStart(ifBlock);
620+
}
621+
RankedTensorType writeCommitsType =
622+
cast<RankedTensorType>(op.getWriteCommitsType());
623+
Value writeCommits = tti::createLoadScratchMemory(
624+
b, loc, op.getWriteCommits(), writeCommitsType)
625+
->getResult(0);
626+
627+
// Gluon pseudo-code:
628+
// write_commits = tl.where(write_commits > outstanding_num, 0,
629+
// write_commits)
630+
631+
Type elementType = writeCommitsType.getElementType();
632+
Value outstandingNum = b.create<arith::ConstantOp>(
633+
loc, elementType,
634+
b.getIntegerAttr(elementType, op.getOutstandingNum()));
635+
Value writeCommitsZero =
636+
tti::createConstIntTensor(b, loc, 0, writeCommitsType);
637+
auto writeCommitsGtOutstandingNum = createCmpIntTensorScalar(
638+
b, loc, writeCommits, outstandingNum, arith::CmpIPredicate::sgt);
639+
writeCommits = b.create<arith::SelectOp>(loc, writeCommitsGtOutstandingNum,
640+
writeCommitsZero, writeCommits);
641+
tti::createStoreScratchMemory(b, loc, op.getWriteCommits(), writeCommits,
642+
writeCommitsType);
643+
b.eraseOp(op);
644+
return success();
645+
}
646+
};
647+
648+
struct CheckWriteCommitOpConversion
649+
: public ConvertOpToLLVMPattern<tti::ExperimentalCheckWriteCommitOp> {
650+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
651+
652+
LogicalResult matchAndRewrite(tti::ExperimentalCheckWriteCommitOp op,
653+
OpAdaptor adaptor,
654+
ConversionPatternRewriter &b) const override {
655+
Location loc = op.getLoc();
656+
OpBuilder::InsertionGuard guard(b);
657+
b.setInsertionPoint(op);
658+
if (op.getPred()) {
659+
auto [prevBlock, ifBlock, thenBlock] =
660+
createIfBlock(b, loc, op.getPred());
661+
b.setInsertionPointToStart(ifBlock);
662+
}
663+
TypedValue<RankedTensorType> buffers = op.getBuffers();
664+
RankedTensorType writeCommitsType =
665+
cast<RankedTensorType>(op.getWriteCommitsType());
666+
Value writeCommits = tti::createLoadScratchMemory(
667+
b, loc, op.getWriteCommits(), writeCommitsType)
668+
->getResult(0);
669+
Value buf = createMemDescToI64(b, loc, getTypeConverter(),
670+
op.getBuf().getType(), adaptor.getBuf());
671+
672+
// Gluon pseudo-code:
673+
// curr_commits = tl.where(buf == buffers, write_commits, 0)
674+
// tl.device_assert(curr_commits == 0, "Buffer being accessed has
675+
// outstanding writes")
676+
677+
Type elementType = writeCommitsType.getElementType();
678+
auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
679+
auto zero = b.create<arith::ConstantOp>(loc, elementType,
680+
b.getIntegerAttr(elementType, 0));
681+
auto writeCommitsZero =
682+
tti::createConstIntTensor(b, loc, 0, writeCommitsType);
683+
auto currCommits = b.create<arith::SelectOp>(
684+
loc, buffersEqBuf, writeCommits, writeCommitsZero);
685+
auto currCommitsEqZero =
686+
createCmpIntTensorScalar(b, loc, currCommits, zero);
687+
b.create<tti::ExperimentalAssertInThreadOp>(
688+
loc, currCommitsEqZero,
689+
b.getStringAttr("Buffer being accessed has outstanding writes"), false);
690+
b.eraseOp(op);
691+
return success();
692+
}
693+
};
694+
515695
} // namespace
516696

517697
void mlir::triton::populateInstrumentationToLLVMPatterns(
@@ -525,4 +705,8 @@ void mlir::triton::populateInstrumentationToLLVMPatterns(
525705
patterns.add<MarkAsReadOpConversion>(typeConverter);
526706
patterns.add<ClearWriteBarrierOpConversion>(typeConverter);
527707
patterns.add<ClearReadBarrierOpConversion>(typeConverter);
708+
patterns.add<StageWriteForCommitOpConversion>(typeConverter);
709+
patterns.add<CommitWritesOpConversion>(typeConverter);
710+
patterns.add<ClearWriteCommitsOpConversion>(typeConverter);
711+
patterns.add<CheckWriteCommitOpConversion>(typeConverter);
528712
}

lib/Dialect/TritonInstrument/IR/Utility.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ TypedValue<RankedTensorType> createConstIntTensor(OpBuilder &builder,
1212
Location loc, int val,
1313
RankedTensorType tensorType) {
1414
auto denseAttr = DenseElementsAttr::get(
15-
tensorType,
16-
APInt(tensorType.getElementType().getIntOrFloatBitWidth(), val));
15+
tensorType, APInt(tensorType.getElementType().getIntOrFloatBitWidth(),
16+
val, /*isSigned=*/true));
1717
return cast<TypedValue<RankedTensorType>>(
1818
builder.create<arith::ConstantOp>(loc, tensorType, denseAttr)
1919
.getResult());

0 commit comments

Comments
 (0)