Skip to content

Commit 2b5b6f7

Browse files
[Intel] replace TritonGPUToLLVM/Utility.h macros with TritonLLVMOpBuilder
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 38563a8 commit 2b5b6f7

21 files changed

+1343
-1172
lines changed

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
183183
MLIRContext *ctx = rewriter.getContext();
184184
VectorType resType = op.getRes().getType();
185185
Location loc = op->getLoc();
186+
auto b = TritonLLVMOpBuilder(loc, rewriter);
186187

187188
Value ptr = op.getPtr();
188189
Value baseWidth = op.getBaseWidth();
@@ -199,7 +200,7 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
199200

200201
// The IGC intrinsic requires the first argument be int64
201202
ptr = rewriter.create<LLVM::PtrToIntOp>(loc, int64Ty, ptr);
202-
Value one = i32_val(1);
203+
Value one = b.i32_val(1);
203204

204205
SmallVector<Type> argTypes{int64Ty,
205206
baseWidth.getType(),
@@ -216,18 +217,18 @@ createGenISA2DBlockRead(TritonGEN::Matrix2DBlockLoadOp op,
216217
int32Ty};
217218

218219
SmallVector<Value> args{ptr,
219-
sub(baseWidth, one),
220-
sub(baseHeight, one),
221-
sub(basePitch, one),
220+
b.sub(baseWidth, one),
221+
b.sub(baseHeight, one),
222+
b.sub(basePitch, one),
222223
x,
223224
y,
224-
i32_val(op.getElemSizeInBits()),
225-
i32_val(op.getTileWidth()),
226-
i32_val(op.getTileHeight()),
227-
i32_val(op.getVBlocks()),
228-
i1_val(op.getTranspose()),
229-
i1_val(op.getVnniTransform()),
230-
i32_val(static_cast<int>(op.getCacheControl()))};
225+
b.i32_val(op.getElemSizeInBits()),
226+
b.i32_val(op.getTileWidth()),
227+
b.i32_val(op.getTileHeight()),
228+
b.i32_val(op.getVBlocks()),
229+
b.i1_val(op.getTranspose()),
230+
b.i1_val(op.getVnniTransform()),
231+
b.i32_val(static_cast<int>(op.getCacheControl()))};
231232

232233
LLVM::CallOp call = createDeviceFunctionCall(
233234
rewriter, funcName, resType, argTypes, args, {}, noUnwindWillReturnAttrs);
@@ -291,6 +292,7 @@ createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
291292
ConversionPatternRewriter &rewriter) {
292293
MLIRContext *ctx = rewriter.getContext();
293294
Location loc = op->getLoc();
295+
auto b = TritonLLVMOpBuilder(loc, rewriter);
294296

295297
// The IGC intrinsic requires the first argument be int64
296298
Value ptr = op.getPtr();
@@ -305,7 +307,7 @@ createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
305307
VectorType storeValType = op.getStoredVal().getType();
306308
std::string funcName =
307309
"llvm.genx.GenISA.LSC2DBlockWrite." + getGenISATypeMangling(storeValType);
308-
Value one = i32_val(1);
310+
Value one = b.i32_val(1);
309311

310312
SmallVector<Type> argTypes{
311313
int_ty(64), baseWidth.getType(), baseHeight.getType(),
@@ -314,18 +316,18 @@ createGenISA2DBlockWrite(TritonGEN::Matrix2DBlockStoreOp op,
314316
int_ty(32), int_ty(1), int_ty(1),
315317
int_ty(32), storeVal.getType()};
316318
SmallVector<Value> args{ptr,
317-
sub(baseWidth, one),
318-
sub(baseHeight, one),
319-
sub(basePitch, one),
319+
b.sub(baseWidth, one),
320+
b.sub(baseHeight, one),
321+
b.sub(basePitch, one),
320322
x,
321323
y,
322-
i32_val(op.getElemSizeInBits()),
323-
i32_val(op.getTileWidth()),
324-
i32_val(op.getTileHeight()),
325-
i32_val(op.getVBlocks()),
326-
i1_val(false), // transpose
327-
i1_val(false), // vnniTransform
328-
i32_val(static_cast<int>(op.getCacheControl())),
324+
b.i32_val(op.getElemSizeInBits()),
325+
b.i32_val(op.getTileWidth()),
326+
b.i32_val(op.getTileHeight()),
327+
b.i32_val(op.getVBlocks()),
328+
b.i1_val(false), // transpose
329+
b.i1_val(false), // vnniTransform
330+
b.i32_val(static_cast<int>(op.getCacheControl())),
329331
storeVal};
330332

331333
LLVM::CallOp call =
@@ -339,6 +341,7 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
339341
ConversionPatternRewriter &rewriter) {
340342
MLIRContext *ctx = rewriter.getContext();
341343
Location loc = op->getLoc();
344+
auto b = TritonLLVMOpBuilder(loc, rewriter);
342345

343346
// The IGC intrinsic requires the first argument be int64
344347
Value ptr = op.getPtr();
@@ -348,7 +351,7 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
348351
Value basePitch = op.getBasePitch();
349352
Value x = op.getX();
350353
Value y = op.getY();
351-
Value one = i32_val(1);
354+
Value one = b.i32_val(1);
352355

353356
SmallVector<Type> argTypes{
354357
int_ty(64), baseWidth.getType(), baseHeight.getType(),
@@ -357,18 +360,18 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
357360
int_ty(32), int_ty(1), int_ty(1),
358361
int_ty(32)};
359362
SmallVector<Value> args{ptr,
360-
sub(baseWidth, one),
361-
sub(baseHeight, one),
362-
sub(basePitch, one),
363+
b.sub(baseWidth, one),
364+
b.sub(baseHeight, one),
365+
b.sub(basePitch, one),
363366
x,
364367
y,
365-
i32_val(op.getElemSizeInBits()),
366-
i32_val(op.getTileWidth()),
367-
i32_val(op.getTileHeight()),
368-
i32_val(op.getVBlocks()),
369-
i1_val(false), // transpose
370-
i1_val(false), // vnniTransform
371-
i32_val(static_cast<int>(op.getCacheControl()))};
368+
b.i32_val(op.getElemSizeInBits()),
369+
b.i32_val(op.getTileWidth()),
370+
b.i32_val(op.getTileHeight()),
371+
b.i32_val(op.getVBlocks()),
372+
b.i1_val(false), // transpose
373+
b.i1_val(false), // vnniTransform
374+
b.i32_val(static_cast<int>(op.getCacheControl()))};
372375

373376
const StringLiteral funcName = "llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid";
374377
return createDeviceFunctionCall(rewriter, funcName, void_ty(ctx), {argTypes},
@@ -485,11 +488,12 @@ struct TritonMatrix2DBlockLoadLowering
485488
ConversionPatternRewriter &rewriter) const override {
486489
MLIRContext *ctx = rewriter.getContext();
487490
Location loc = op->getLoc();
491+
auto b = TritonLLVMOpBuilder(loc, rewriter);
488492
VectorType resType = op.getRes().getType();
489493

490494
auto dest = rewriter.create<LLVM::AllocaOp>(
491495
loc, ptr_ty(ctx), resType.getElementType(),
492-
i32_val(resType.getNumElements()));
496+
b.i32_val(resType.getNumElements()));
493497
std::string fnName = "intel_sub_group_2d_block_read_";
494498
if (op.getVnniTransform())
495499
fnName += "transform_";
@@ -503,9 +507,10 @@ struct TritonMatrix2DBlockLoadLowering
503507
fnName +=
504508
intel::getTypeMangling(resType.getElementType(), /*isUnsigned=*/true);
505509
VectorType vecType = vec_ty(i32_ty, 2);
506-
Value byteCoord = insert_element(
507-
vecType, insert_element(vecType, undef(vecType), op.getX(), i32_val(0)),
508-
op.getY(), i32_val(1));
510+
Value byteCoord = b.insert_element(
511+
vecType,
512+
b.insert_element(vecType, b.undef(vecType), op.getX(), b.i32_val(0)),
513+
op.getY(), b.i32_val(1));
509514
SmallVector<Type> argTypes{ptr_ty(ctx, 1), i32_ty, i32_ty,
510515
i32_ty, vecType, ptr_ty(ctx)};
511516
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(),
@@ -545,11 +550,12 @@ struct TritonMatrix2DBlockStoreLowering
545550
ConversionPatternRewriter &rewriter) const override {
546551
MLIRContext *ctx = rewriter.getContext();
547552
Location loc = op->getLoc();
553+
auto b = TritonLLVMOpBuilder(loc, rewriter);
548554

549555
VectorType storeValType = op.getStoredVal().getType();
550556
auto storeValPtr = rewriter.create<LLVM::AllocaOp>(
551557
loc, ptr_ty(ctx), storeValType.getElementType(),
552-
i32_val(storeValType.getNumElements()));
558+
b.i32_val(storeValType.getNumElements()));
553559
rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), storeValPtr);
554560

555561
std::string fnName = "intel_sub_group_2d_block_write_";
@@ -565,9 +571,10 @@ struct TritonMatrix2DBlockStoreLowering
565571
: "h";
566572

567573
VectorType vecType = vec_ty(i32_ty, 2);
568-
Value byteCoord = insert_element(
569-
vecType, insert_element(vecType, undef(vecType), op.getX(), i32_val(0)),
570-
op.getY(), i32_val(1));
574+
Value byteCoord = b.insert_element(
575+
vecType,
576+
b.insert_element(vecType, b.undef(vecType), op.getX(), b.i32_val(0)),
577+
op.getY(), b.i32_val(1));
571578
SmallVector<Type> argTypes{ptr_ty(ctx, 1), i32_ty, i32_ty,
572579
i32_ty, vecType, ptr_ty(ctx)};
573580
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(),
@@ -607,16 +614,18 @@ struct TritonMatrix2DBlockPrefetchLowering
607614
ConversionPatternRewriter &rewriter) const override {
608615
MLIRContext *ctx = rewriter.getContext();
609616
Location loc = op->getLoc();
617+
auto b = TritonLLVMOpBuilder(loc, rewriter);
610618
std::string fnName = "intel_sub_group_2d_block_prefetch_";
611619
fnName += std::to_string(op.getElemSizeInBits()) + "b_" +
612620
std::to_string(op.getTileHeight()) + "r" +
613621
std::to_string(op.getTileWidth()) + "x" +
614622
std::to_string(op.getVBlocks()) + "c";
615623
fnName = "_Z" + std::to_string(fnName.size()) + fnName + "PU3AS1viiiDv2_i";
616624
VectorType vecType = vec_ty(i32_ty, 2);
617-
Value byteCoord = insert_element(
618-
vecType, insert_element(vecType, undef(vecType), op.getX(), i32_val(0)),
619-
op.getY(), i32_val(1));
625+
Value byteCoord = b.insert_element(
626+
vecType,
627+
b.insert_element(vecType, b.undef(vecType), op.getX(), b.i32_val(0)),
628+
op.getY(), b.i32_val(1));
620629
SmallVector<Type> argTypes{ptr_ty(ctx, 1), i32_ty, i32_ty, i32_ty, vecType};
621630
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
622631
op.getBasePitch(), byteCoord};

third_party/intel/lib/TritonIntelGPUToLLVM/BF16Casts.cpp

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ struct TruncBF16 : ConvertOpToLLVMPattern<arith::TruncFOp> {
7575
namespace mlir::triton::intel {
7676
Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
7777
Value v) {
78+
auto b = TritonLLVMOpBuilder(loc, rewriter);
7879
if (auto definingOp = v.getDefiningOp()) {
7980
auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
8081
if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
@@ -86,19 +87,20 @@ Value convertBf16ToFp32(Location loc, ConversionPatternRewriter &rewriter,
8687
auto ext_func = triton::gpu::intel::lookupOrCreateSPIRVFn(moduleOp, name,
8788
inTy, outTy);
8889
auto call = triton::gpu::intel::createSPIRVBuiltinCall(
89-
loc, rewriter, ext_func, bitcast(v, inTy).getResult());
90+
loc, rewriter, ext_func, b.bitcast(v, inTy).getResult());
9091
return call.getResult();
9192
}
9293
}
9394

94-
auto as_int16 = bitcast(v, i16_ty);
95-
auto as_int32 = zext(i32_ty, as_int16);
96-
auto shifted = shl(i32_ty, as_int32, i32_val(16));
97-
return (bitcast(shifted, f32_ty));
95+
auto as_int16 = b.bitcast(v, i16_ty);
96+
auto as_int32 = b.zext(i32_ty, as_int16);
97+
auto shifted = b.shl(i32_ty, as_int32, b.i32_val(16));
98+
return (b.bitcast(shifted, f32_ty));
9899
}
99100

100101
Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
101102
Value v, RoundingMode rounding) {
103+
auto b = TritonLLVMOpBuilder(loc, rewriter);
102104
if (auto definingOp = v.getDefiningOp()) {
103105
auto moduleOp = definingOp->getParentWithTrait<OpTrait::SymbolTable>();
104106
if (moduleOp->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
@@ -114,36 +116,36 @@ Value convertFp32ToBf16(Location loc, ConversionPatternRewriter &rewriter,
114116
moduleOp, name, inTy, funcOutTy);
115117
auto call = triton::gpu::intel::createSPIRVBuiltinCall(loc, rewriter,
116118
trunc_func, v);
117-
return bitcast(call.getResult(), outTy);
119+
return b.bitcast(call.getResult(), outTy);
118120
}
119121
}
120122

121123
assert(!isa<VectorType>(v.getType()) && "Not yet supported");
122124

123-
auto as_uint32 = bitcast(v, i32_ty);
125+
auto as_uint32 = b.bitcast(v, i32_ty);
124126
auto check_exponent =
125-
and_(i32_ty, xor_(i32_ty, as_uint32, i32_val(0xffffffff)),
126-
i32_val(0x7f800000));
127-
auto exponent_not_all1s = icmp_ne(check_exponent, i32_val(0));
128-
auto exponent_all1s = icmp_eq(check_exponent, i32_val(0));
127+
b.and_(i32_ty, b.xor_(i32_ty, as_uint32, b.i32_val(0xffffffff)),
128+
b.i32_val(0x7f800000));
129+
auto exponent_not_all1s = b.icmp_ne(check_exponent, b.i32_val(0));
130+
auto exponent_all1s = b.icmp_eq(check_exponent, b.i32_val(0));
129131
Value rounded = as_uint32;
130132
if (rounding == RoundingMode::RTNE) {
131-
rounded =
132-
add(i32_ty, i32_val(0x7fff),
133-
and_(i32_ty, lshr(i32_ty, as_uint32, i32_val(16)), i32_val(1)));
134-
rounded = add(i32_ty, rounded, as_uint32);
135-
rounded = select(exponent_not_all1s, rounded, as_uint32);
133+
rounded = b.add(
134+
i32_ty, b.i32_val(0x7fff),
135+
b.and_(i32_ty, b.lshr(i32_ty, as_uint32, b.i32_val(16)), b.i32_val(1)));
136+
rounded = b.add(i32_ty, rounded, as_uint32);
137+
rounded = b.select(exponent_not_all1s, rounded, as_uint32);
136138
}
137139

138-
auto preserve_nan =
139-
and_(i1_ty, exponent_all1s,
140-
icmp_ne(and_(i32_ty, as_uint32, i32_val(0xffff)), i32_val(0)));
141-
auto nan = or_(i32_ty, as_uint32, i32_val(0x10000));
142-
Value res = select(preserve_nan, nan, rounded);
140+
auto preserve_nan = b.and_(
141+
i1_ty, exponent_all1s,
142+
b.icmp_ne(b.and_(i32_ty, as_uint32, b.i32_val(0xffff)), b.i32_val(0)));
143+
auto nan = b.or_(i32_ty, as_uint32, b.i32_val(0x10000));
144+
Value res = b.select(preserve_nan, nan, rounded);
143145

144-
auto shifted = lshr(i32_ty, res, i32_val(16));
145-
auto truncated = trunc(i16_ty, shifted);
146-
return bitcast(truncated, bf16_ty);
146+
auto shifted = b.lshr(i32_ty, res, b.i32_val(16));
147+
auto truncated = b.trunc(i16_ty, shifted);
148+
return b.bitcast(truncated, bf16_ty);
147149
}
148150

149151
void populateBF16CastsLLVMPatterns(LLVMTypeConverter &typeConverter,

third_party/intel/lib/TritonIntelGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ struct ReturnOpConversion
3838
Value packedResults =
3939
rewriter.create<LLVM::UndefOp>(op.getLoc(), packedResultsTy);
4040
auto loc = op.getLoc();
41+
auto b = TritonLLVMOpBuilder(loc, rewriter);
4142
for (auto it : llvm::enumerate(adaptor.getOperands())) {
42-
packedResults = insert_val(packedResultsTy, packedResults, it.value(),
43-
it.index());
43+
packedResults = b.insert_val(packedResultsTy, packedResults,
44+
it.value(), it.index());
4445
}
4546
newOp = rewriter.create<LLVM::ReturnOp>(op.getLoc(), packedResults);
4647
}

0 commit comments

Comments
 (0)