
Commit 38563a8

Merge commit 'bc4675aaa291097d96dc21183296c595947a27bc'
2 parents: 0b078e8 + bc4675a


46 files changed: +2122 −1660 lines

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 341 additions & 182 deletions
Large diffs are not rendered by default.
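
The Utility.h diff is collapsed here, but the pattern visible in the rendered files below is that the former free helpers (i32_val, int_val, urem, gep, barrier, ...) become methods on a TritonLLVMOpBuilder constructed from a location and a rewriter. As a rough orientation only — the real class lives in the hidden Utility.h diff above — a minimal sketch of what such a builder could look like, assuming it simply captures (loc, builder) and forwards each helper to the corresponding LLVM dialect op:

    // Hypothetical sketch, not the actual Utility.h contents.
    #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
    #include "mlir/IR/Builders.h"

    struct TritonLLVMOpBuilderSketch {
      mlir::Location loc;
      mlir::OpBuilder &builder;

      // Integer constant of a given bit width (int_val in the diffs below).
      mlir::Value int_val(unsigned width, int64_t v) {
        auto ty = builder.getIntegerType(width);
        return builder.create<mlir::LLVM::ConstantOp>(
            loc, ty, builder.getIntegerAttr(ty, v));
      }
      // 32-bit integer constant (i32_val in the diffs below).
      mlir::Value i32_val(int32_t v) { return int_val(32, v); }

      // Arithmetic helpers forward to the LLVM dialect ops.
      mlir::Value add(mlir::Value a, mlir::Value b) {
        return builder.create<mlir::LLVM::AddOp>(loc, a, b);
      }
      mlir::Value urem(mlir::Value a, mlir::Value b) {
        return builder.create<mlir::LLVM::URemOp>(loc, a, b);
      }
      // gep, icmp_eq, barrier, store, load, ... would follow the same
      // capture-and-forward pattern.
    };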

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 7 additions & 6 deletions
@@ -18,17 +18,18 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
   matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto ctx = rewriter.getContext();
     auto typeConverter = getTypeConverter();
     auto elems = unpackLLElements(loc, adaptor.getCondition(), rewriter);
     auto elemTy = elems[0].getType();
-    Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0);
+    Value condition = b.int_val(elemTy.getIntOrFloatBitWidth(), 0);
     for (auto elem : elems) {
       if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) {
-        condition =
-            or_(condition,
-                icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
-                                  loc, elemTy, rewriter.getZeroAttr(elemTy))));
+        condition = b.or_(
+            condition,
+            b.icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
+                                loc, elemTy, rewriter.getZeroAttr(elemTy))));
       } else {
         assert(false && "Unsupported type for assert");
         return failure();
@@ -41,7 +42,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
       // tensor in those two operations may have different layout we need to
       // make sure all the threads are done executing the assert before going to
       // the next op.
-      barrier();
+      b.barrier();
     }
     rewriter.eraseOp(op);
     return success();
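
The AssertOp change above shows the mechanical shape of the whole commit: construct the builder once, right after loc is available, then prefix the former free helpers with b. A condensed before/after sketch of that pattern (illustrative only, not copied verbatim from the diff; bitWidth, elem, and zero are hypothetical stand-ins):

    // Before: free helpers that implicitly use the enclosing loc/rewriter.
    //   Value condition = int_val(bitWidth, 0);
    //   condition = or_(condition, icmp_eq(elem, zero));
    //   barrier();

    // After: an explicit builder carries loc and rewriter.
    auto loc = op.getLoc();
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    Value condition = b.int_val(bitWidth, 0);             // i<bitWidth> constant 0
    condition = b.or_(condition, b.icmp_eq(elem, zero));  // accumulate "any elem == 0"
    b.barrier();                                          // same barrier, via the builder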

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 6 additions & 4 deletions
@@ -13,6 +13,8 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
   matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+    auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
     if (funcOp->hasAttr("nvvm.kernel")) {
       // A GPU kernel
       if (op.getNumOperands() > 0) {
@@ -34,10 +36,9 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
                                            funcOp.getResultTypes());
       Value packedResults =
           rewriter.create<LLVM::UndefOp>(op.getLoc(), packedResultsTy);
-      auto loc = op.getLoc();
       for (auto it : llvm::enumerate(adaptor.getOperands())) {
-        packedResults = insert_val(packedResultsTy, packedResults, it.value(),
-                                   it.index());
+        packedResults = b.insert_val(packedResultsTy, packedResults,
+                                     it.value(), it.index());
       }
       newOp = rewriter.create<LLVM::ReturnOp>(op.getLoc(), packedResults);
     }
@@ -78,6 +79,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
     // Get the last argument of the caller, which is the current stack pointer
     // of shared memory and append it to the operands of the callOp.
     auto loc = callOp.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto caller = callOp->getParentOfType<FunctionOpInterface>();
     auto promotedOperands = this->getTypeConverter()->promoteOperands(
         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
@@ -95,7 +97,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
     Value opOffsetVal;
     if (opOffsetAttr) {
       auto opOffset = opOffsetAttr.getValue().getZExtValue();
-      opOffsetVal = i32_val(opOffset);
+      opOffsetVal = b.i32_val(opOffset);
     }

     promotedOperands.push_back(

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 50 additions & 44 deletions
@@ -62,6 +62,7 @@ struct ConvertLayoutOpConversion
                       ArrayRef<unsigned> origRepShape,
                       ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
                       Value smemBase) const {
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
     auto layout = type.getEncoding();
     auto rank = type.getRank();
@@ -110,29 +111,29 @@ struct ConvertLayoutOpConversion
       Value offset = LLVM::linearize(rewriter, loc, multiDimOffsetWrapped,
                                      paddedRepShape, outOrd);
       auto elemPtrTy = smemBase.getType();
-      Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset);
+      Value ptr = b.gep(elemPtrTy, llvmElemTy, smemBase, offset);
       auto vecTy = vec_ty(llvmElemTy, vec);
       if (stNotRd) {
-        Value valVec = undef(vecTy);
+        Value valVec = b.undef(vecTy);
         for (unsigned v = 0; v < vec; ++v) {
           auto currVal = vals[elemId + linearCTAId * accumSizePerThread + v];
           if (isInt1)
-            currVal = zext(llvmElemTy, currVal);
+            currVal = b.zext(llvmElemTy, currVal);
           else if (isPtr)
-            currVal = ptrtoint(llvmElemTy, currVal);
-          valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
+            currVal = b.ptrtoint(llvmElemTy, currVal);
+          valVec = b.insert_element(vecTy, valVec, currVal, b.i32_val(v));
         }
-        store(valVec, ptr);
+        b.store(valVec, ptr);
       } else {
-        Value valVec = load(vecTy, ptr);
+        Value valVec = b.load(vecTy, ptr);
         for (unsigned v = 0; v < vec; ++v) {
-          Value currVal = extract_element(llvmElemTy, valVec, i32_val(v));
+          Value currVal = b.extract_element(llvmElemTy, valVec, b.i32_val(v));
           if (isInt1)
-            currVal = icmp_ne(currVal,
-                              rewriter.create<LLVM::ConstantOp>(
-                                  loc, i8_ty, rewriter.getI8IntegerAttr(0)));
+            currVal = b.icmp_ne(
+                currVal, rewriter.create<LLVM::ConstantOp>(
+                             loc, i8_ty, rewriter.getI8IntegerAttr(0)));
           else if (isPtr)
-            currVal = inttoptr(llvmElemTyOrig, currVal);
+            currVal = b.inttoptr(llvmElemTyOrig, currVal);
           vals[elemId + linearCTAId * accumSizePerThread + v] = currVal;
         }
       }
@@ -146,6 +147,7 @@ struct ConvertLayoutOpConversion
                    ConversionPatternRewriter &rewriter,
                    const TargetInfoBase &targetInfo) const {
     auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto typeConverter = getTypeConverter();
     RankedTensorType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
@@ -205,12 +207,12 @@ struct ConvertLayoutOpConversion
       auto multiDimRepId =
           getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
       if (repId != 0) {
-        barrier();
+        b.barrier();
       }
       processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
                      multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
                      vals, smemBase);
-      barrier();
+      b.barrier();
       processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
                      multiDimRepId, outVec, paddedRepShape, origRepShape,
                      outOrd, outVals, smemBase);
@@ -355,6 +357,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                    ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();

@@ -399,9 +402,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // Munge input values
     for (const auto &it : llvm::enumerate(inVals)) {
       if (isSubByteInt) {
-        inVals[it.index()] = zext(llvmElemTy, it.value());
+        inVals[it.index()] = b.zext(llvmElemTy, it.value());
       } else if (isPtr) {
-        inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
+        inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value());
       }
     }

@@ -417,9 +420,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // Unmunge output values
     for (const auto &it : llvm::enumerate(outVals)) {
       if (isSubByteInt) {
-        outVals[it.index()] = trunc(llvmElemTyOrig, it.value());
+        outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value());
       } else if (isPtr) {
-        outVals[it.index()] = inttoptr(llvmElemTyOrig, it.value());
+        outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value());
       }
     }

@@ -443,6 +446,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                    ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
+    auto b = TritonLLVMOpBuilder(loc, rewriter);

     StringAttr kRegister = str_attr("register");
     StringAttr kLane = str_attr("lane");
@@ -452,9 +456,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     StringAttr kIteration = str_attr("iteration");

     Value threadId = getThreadId(rewriter, loc);
-    Value threadsPerWarp = i32_val(srcLayout.getInDimSize(kLane));
-    Value laneId = urem(threadId, threadsPerWarp);
-    Value warpId = udiv(threadId, threadsPerWarp);
+    Value threadsPerWarp = b.i32_val(srcLayout.getInDimSize(kLane));
+    Value laneId = b.urem(threadId, threadsPerWarp);
+    Value warpId = b.udiv(threadId, threadsPerWarp);

     auto scratchConfig =
         getScratchConfigForCvt(op.getSrc().getType(), op.getType());
@@ -541,37 +545,38 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                                         {kWarp, 0},
                                         {kBlock, 0}})[0]
                         .second;
-      Value offset = xor_(regBase, i32_val(regIdx));
+      Value offset = b.xor_(regBase, b.i32_val(regIdx));
       if (paddedSize > 0) {
         assert(llvm::isPowerOf2_32(paddedStride));
         assert(llvm::isPowerOf2_32(paddedSize));
         auto rshiftVal = llvm::Log2_32(paddedStride);
         auto lshiftVal = llvm::Log2_32(paddedSize);
-        offset = add(shl(lshr(offset, i32_val(rshiftVal)), i32_val(lshiftVal)),
-                     offset);
+        offset = b.add(
+            b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
+            offset);
       }
-      auto vecAddr = gep(sharedPtrTy, elemTy, smemBase, offset);
+      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset);
       vecAddr.setInbounds(true);
       return vecAddr;
     };

     auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
-                                       {{kRegister, i32_val(0)},
+                                       {{kRegister, b.i32_val(0)},
                                         {kLane, laneId},
                                         {kWarp, warpId},
-                                        {kBlock, i32_val(0)}})[0]
+                                        {kBlock, b.i32_val(0)}})[0]
                          .second;
     auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout,
-                                      {{kRegister, i32_val(0)},
+                                      {{kRegister, b.i32_val(0)},
                                        {kLane, laneId},
                                        {kWarp, warpId},
-                                       {kBlock, i32_val(0)}})[0]
+                                       {kBlock, b.i32_val(0)}})[0]
                         .second;
     // register idx -> Value
     llvm::MapVector<int, Value> outVals;
     for (int i = 0; i < iterations; i++) {
       if (i != 0)
-        barrier();
+        b.barrier();

       auto &inRegs = inRegsForIter[i];
       auto &outRegs = outRegsForIter[i];
@@ -591,19 +596,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
           targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
         } else {
           targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
-                                  /*pred=*/true_val());
+                                  /*pred=*/b.true_val());
         }
       }

-      barrier();
+      b.barrier();

       for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) {
         auto outRegSlice = outRegs[j];
         auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice);
         Value valsVec =
             targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
                                    vec_ty(elemTy, scratchConfig.outVec),
-                                   /*pred=*/true_val());
+                                   /*pred=*/b.true_val());
         for (Value v : unpackLLVector(loc, valsVec, rewriter))
           outVals[outRegSlice++] = v;
       }
@@ -646,6 +651,7 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
     ConversionPatternRewriter &rewriter) const {
   MLIRContext *ctx = op.getContext();
   Location loc = op.getLoc();
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
   StringAttr kRegister = str_attr("register");
   StringAttr kLane = str_attr("lane");
   assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
@@ -657,8 +663,8 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
   SmallVector<Value> shflOuts(Cp.getInDimSize(kRegister));

   Value threadId = getThreadId(rewriter, loc);
-  Value threadsPerWarp = i32_val(Cp.getInDimSize(kLane));
-  Value laneId = urem(threadId, threadsPerWarp);
+  Value threadsPerWarp = b.i32_val(Cp.getInDimSize(kLane));
+  Value laneId = b.urem(threadId, threadsPerWarp);

   // Emit one shuffle per destination register.
   for (int i : llvm::seq(shflOuts.size())) {
@@ -667,22 +673,22 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
     // At the same time, for each register, P1 returns the source value index
     // to provide as the shuffle value.
     auto out = applyLinearLayout(loc, rewriter, P1,
-                                 {{kLane, laneId}, {kRegister, i32_val(i)}});
+                                 {{kLane, laneId}, {kRegister, b.i32_val(i)}});
     assert(out.size() == 1);
     Value srcRegIdx = out.front().second;
     // The size of the input lane dimension is the number of selects to emit.
     // TODO(jeff): For dtypes smaller than i32, we can use byte permutes and
     // shuffle multiple values at a time.
-    Value shflSrc = undef(srcValues.front().getType());
+    Value shflSrc = b.undef(srcValues.front().getType());
     for (int j : llvm::seq(reducedP1.getInDimSize(kLane))) {
       int32_t check =
           reducedP1.apply({{kLane, j}, {kRegister, i}}).front().second;
-      shflSrc =
-          select(icmp_eq(srcRegIdx, i32_val(check)), srcValues[check], shflSrc);
+      shflSrc = b.select(b.icmp_eq(srcRegIdx, b.i32_val(check)),
+                         srcValues[check], shflSrc);
     }

     out = applyLinearLayout(loc, rewriter, Cp,
-                            {{kLane, laneId}, {kRegister, i32_val(i)}});
+                            {{kLane, laneId}, {kRegister, b.i32_val(i)}});
     assert(out.size() == 1);
     Value shflIdx = out.front().second;
     shflOuts[i] = targetInfo.shuffleIdx(rewriter, loc, shflSrc, shflIdx);
@@ -693,16 +699,16 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
   // selects.
   SmallVector<Value> results(shflOuts.size());
   for (int i : llvm::seq(results.size())) {
-    Value result = undef(srcValues.front().getType());
+    Value result = b.undef(srcValues.front().getType());

     auto out = applyLinearLayout(loc, rewriter, P2inv,
-                                 {{kLane, laneId}, {kRegister, i32_val(i)}});
+                                 {{kLane, laneId}, {kRegister, b.i32_val(i)}});
     Value resultIdx = out.front().second;
     for (int j : llvm::seq(reducedP2.getInDimSize(kLane))) {
       int32_t check =
           reducedP2.apply({{kLane, j}, {kRegister, i}}).front().second;
-      result =
-          select(icmp_eq(resultIdx, i32_val(check)), shflOuts[check], result);
+      result = b.select(b.icmp_eq(resultIdx, b.i32_val(check)), shflOuts[check],
+                        result);
     }
     results[i] = result;
   }