Skip to content

Commit 98c2d04

Browse files
committed
self review and optimizing
1 parent a2683b1 commit 98c2d04

File tree

2 files changed

+49
-66
lines changed

2 files changed

+49
-66
lines changed

lib/Conversion/TPtrToLLVM/TPtrToLLVM.cpp

Lines changed: 44 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88
#include "mlir/IR/BuiltinTypes.h"
99
#include "mlir/IR/PatternMatch.h"
1010
#include "mlir/IR/Value.h"
11+
#include "llvm/Support/Debug.h"
12+
#include "llvm/Support/raw_ostream.h"
1113

1214
#include "mlir/Transforms/DialectConversion.h"
1315
#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h"
@@ -24,24 +26,6 @@ static bool isOneToOneCast(UnrealizedConversionCastOp op) {
2426
return (op.getInputs().size() == 1 && op->getNumResults() == 1);
2527
}
2628

27-
Type convertMemRefType(MemRefType type) {
28-
auto ctx = type.getContext();
29-
auto rank = type.getShape().size();
30-
auto i64Ty = IntegerType::get(ctx, 64);
31-
SmallVector<Type, 5> types;
32-
33-
// struct { ptr base_ptr, ptr aligned_ptr, i64 offset,
34-
// array<rank x i64> sizes, array<rank x i64> strides }
35-
types.push_back(LLVM::LLVMPointerType::get(ctx)); // base pointer
36-
types.push_back(LLVM::LLVMPointerType::get(ctx)); // aligned pointer
37-
types.push_back(i64Ty); // offset
38-
types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); // sizes
39-
types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); // strides
40-
41-
Type ret = LLVM::LLVMStructType::getLiteral(ctx, types);
42-
LDBG(" From MemrefConverter convertMemRefType: " << type << " -> " << ret);
43-
return ret;
44-
}
4529

4630
// PtrAddOp -> llvm.getelementptr conversion
4731
struct PtrAddConverter : OpConversionPattern<tptr::PtrAddOp> {
@@ -109,7 +93,7 @@ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
10993
LogicalResult
11094
matchAndRewrite(tptr::ToMemrefOp op, OpAdaptor adaptor,
11195
ConversionPatternRewriter &rewriter) const override {
112-
LDBG("matchAndRewrite: to_memref " << op);
96+
LDBG("matchAndRewrite: to_memref (before) " << op);
11397

11498
auto input = adaptor.getArg();
11599

@@ -125,15 +109,15 @@ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
125109
}
126110
}
127111

128-
Type targetType = convertMemRefType(op.getType());
112+
Type targetType = getTypeConverter()->convertType(cast<MemRefType>(op.getType()));
113+
LDBG("matchAndRewrite: to_memref (typeconverted) " << cast<MemRefType>(op.getType()) << " -> " << targetType);
129114
if (!targetType) {
130115
return rewriter.notifyMatchFailure(op, "failed to convert memref type");
131116
}
132117

133118
auto loc = op.getLoc();
134119
auto i64Ty = rewriter.getIntegerType(64);
135-
auto memrefType = cast<MemRefType>(op.getType());
136-
auto shape = memrefType.getShape();
120+
auto shape = cast<MemRefType>(op.getType()).getShape();
137121
auto rank = shape.size();
138122

139123
Value result = rewriter.create<LLVM::UndefOp>(loc, targetType);
@@ -167,7 +151,7 @@ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
167151
}
168152

169153
rewriter.replaceOp(op, result);
170-
LDBG("matchAndRewrite: to_memref done");
154+
LDBG("matchAndRewrite: to_memref (after) -> " << result);
171155
return success();
172156
}
173157
};
@@ -179,16 +163,17 @@ struct FromMemrefConverter : OpConversionPattern<tptr::FromMemrefOp> {
179163
LogicalResult
180164
matchAndRewrite(tptr::FromMemrefOp op, OpAdaptor adaptor,
181165
ConversionPatternRewriter &rewriter) const override {
182-
LDBG("matchAndRewrite: from_memref " << op);
166+
LDBG("matchAndRewrite: from_memref (before) " << op);
183167

184168
Value input = adaptor.getInput();
185169
if (isa<MemRefType>(input.getType())) {
186-
input =
187-
rewriter
188-
.create<UnrealizedConversionCastOp>(
189-
op.getLoc(),
190-
convertMemRefType(cast<MemRefType>(input.getType())), input)
191-
.getResult(0);
170+
input = rewriter
171+
.create<UnrealizedConversionCastOp>(
172+
op.getLoc(),
173+
getTypeConverter()->convertType(
174+
cast<MemRefType>(input.getType())),
175+
input)
176+
.getResult(0);
192177
}
193178

194179
// Extract base_ptr (index 0)
@@ -197,7 +182,7 @@ struct FromMemrefConverter : OpConversionPattern<tptr::FromMemrefOp> {
197182
op.getLoc(), resultType, input, rewriter.getDenseI64ArrayAttr({0}));
198183

199184
rewriter.replaceOp(op, extractOp);
200-
LDBG("matchAndRewrite: from_memref done");
185+
LDBG("matchAndRewrite: from_memref (after) -> " << extractOp);
201186
return success();
202187
}
203188
};
@@ -225,16 +210,14 @@ struct UnrealizedCastConverter
225210
}
226211

227212
if (isa<ptr::PtrType>(outputType) ||
228-
(isa<LLVM::LLVMPointerType>(inputType) &&
229-
isa<ptr::PtrType>(outputType)) ||
230-
(isa<LLVM::LLVMPointerType>(inputType) &&
231-
isa<MemRefType>(outputType))) {
213+
(isa<LLVM::LLVMPointerType>(inputType) && isa<MemRefType>(outputType))) {
214+
LDBG("UnrealizedCast (reject): unsafe pointer conversion " << op);
232215
return rewriter.notifyMatchFailure(op, "unsafe pointer conversion");
233216
}
234217

235218
if ((isa<LLVM::LLVMStructType>(inputType) && isa<MemRefType>(outputType)) ||
236219
(isa<MemRefType>(inputType) && isa<LLVM::LLVMStructType>(outputType))) {
237-
LDBG("matchAndRewrite: replace with input: " << op << " -> " << input);
220+
LDBG("matchAndRewrite: UnrealizedCast (after) " << op << " -> " << input);
238221
rewriter.replaceOp(op, input);
239222
return success();
240223
}
@@ -276,7 +259,7 @@ struct ConvertControlFlowOp : OpConversionPattern<cf::CondBranchOp> {
276259
LogicalResult
277260
matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
278261
ConversionPatternRewriter &rewriter) const override {
279-
LDBG("matchAndRewrite: cond_branch " << op);
262+
LDBG("matchAndRewrite: cond_branch (before) " << op);
280263

281264
if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
282265
*getTypeConverter())) ||
@@ -285,12 +268,12 @@ struct ConvertControlFlowOp : OpConversionPattern<cf::CondBranchOp> {
285268
return failure();
286269
}
287270

288-
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
271+
auto newOp = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
289272
op, adaptor.getCondition(), op.getTrueDest(),
290273
adaptor.getTrueDestOperands(), op.getFalseDest(),
291274
adaptor.getFalseDestOperands());
292275

293-
LDBG("matchAndRewrite: cond_branch done");
276+
LDBG("matchAndRewrite: cond_branch (after) -> " << newOp);
294277
return success();
295278
}
296279
};
@@ -302,16 +285,17 @@ struct ConvertBranchOp : OpConversionPattern<cf::BranchOp> {
302285
LogicalResult
303286
matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
304287
ConversionPatternRewriter &rewriter) const override {
305-
LDBG("matchAndRewrite: cf.br " << op);
288+
LDBG("matchAndRewrite: cf.br (before) " << op);
306289

307290
if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
308291
*getTypeConverter()))) {
309292
return failure();
310293
}
311294

312-
rewriter.replaceOpWithNewOp<cf::BranchOp>(op, op.getDest(),
313-
adaptor.getDestOperands());
314-
LDBG("matchAndRewrite: cf.br done");
295+
auto newOp =
296+
rewriter.replaceOpWithNewOp<cf::BranchOp>(op, op.getDest(),
297+
adaptor.getDestOperands());
298+
LDBG("matchAndRewrite: cf.br (after) -> " << newOp);
315299
return success();
316300
}
317301
};
@@ -324,7 +308,7 @@ struct MemRefAllocConverter : OpConversionPattern<memref::AllocOp> {
324308
LogicalResult
325309
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
326310
ConversionPatternRewriter &rewriter) const override {
327-
LDBG("matchAndRewrite: memref.alloc " << op);
311+
LDBG("matchAndRewrite: memref.alloc (before) " << op);
328312

329313
auto oldMemRefType = op.getType();
330314
auto elementType = oldMemRefType.getElementType();
@@ -395,7 +379,7 @@ struct MemRefAllocConverter : OpConversionPattern<memref::AllocOp> {
395379
}
396380

397381
rewriter.replaceOp(op, result);
398-
LDBG("matchAndRewrite: memref.alloc done -> LLVM struct");
382+
LDBG("matchAndRewrite: memref.alloc (after) -> " << result);
399383
return success();
400384
}
401385
};
@@ -407,7 +391,7 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
407391
LogicalResult
408392
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
409393
ConversionPatternRewriter &rewriter) const override {
410-
LDBG("matchAndRewrite: memref.store " << op);
394+
LDBG("matchAndRewrite: memref.store (before) " << op);
411395

412396
auto memrefType = op.getMemRef().getType();
413397
if (auto memrefTy = dyn_cast<MemRefType>(memrefType)) {
@@ -473,10 +457,10 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
473457
rewriter.create<LLVM::GEPOp>(loc, ptrTy, ptrTy, basePtr, linearIndex);
474458

475459
// Store the value
476-
rewriter.create<LLVM::StoreOp>(loc, adaptor.getValue(), elementPtr);
460+
auto storeOp =
461+
rewriter.create<LLVM::StoreOp>(loc, adaptor.getValue(), elementPtr);
477462
rewriter.eraseOp(op);
478-
479-
LDBG("matchAndRewrite: memref.store done -> LLVM GEP + store");
463+
LDBG("matchAndRewrite: memref.store (after) -> " << storeOp);
480464
return success();
481465
}
482466

@@ -491,7 +475,7 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
491475
LogicalResult
492476
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
493477
ConversionPatternRewriter &rewriter) const override {
494-
LDBG("matchAndRewrite: memref.load " << op);
478+
LDBG("matchAndRewrite: memref.load (before) " << op);
495479

496480
auto memrefType = op.getMemRef().getType();
497481
if (auto memrefTy = dyn_cast<MemRefType>(memrefType)) {
@@ -567,7 +551,7 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
567551
rewriter.create<LLVM::LoadOp>(loc, newResultType, elementPtr);
568552
rewriter.replaceOp(op, loadedValue);
569553

570-
LDBG("matchAndRewrite: memref.load done -> LLVM GEP + load");
554+
LDBG("matchAndRewrite: memref.load (after) -> " << loadedValue);
571555
return success();
572556
}
573557

@@ -591,14 +575,19 @@ struct TypeOffsetConverter : OpConversionPattern<tptr::TypeOffsetOp> {
591575
LogicalResult
592576
matchAndRewrite(tptr::TypeOffsetOp op, OpAdaptor adaptor,
593577
ConversionPatternRewriter &rewriter) const override {
594-
LDBG("matchAndRewrite: type_offset " << op);
578+
LDBG("matchAndRewrite: type_offset (before) " << op);
595579

596580
auto size = getTypeSize(op);
581+
if (size.isScalable()) {
582+
return rewriter.notifyMatchFailure(op, "scalable type size unsupported");
583+
}
584+
auto fixedSize = static_cast<int64_t>(size.getFixedValue());
597585
auto constOp = rewriter.create<LLVM::ConstantOp>(
598-
op.getLoc(), op.getType(), rewriter.getIntegerAttr(op.getType(), size));
586+
op.getLoc(), op.getType(),
587+
rewriter.getIntegerAttr(op.getType(), fixedSize));
599588

600589
rewriter.replaceOp(op, constOp);
601-
LDBG("matchAndRewrite: type_offset done -> " << size);
590+
LDBG("matchAndRewrite: type_offset (after) -> " << constOp);
602591
return success();
603592
}
604593
};
@@ -613,4 +602,4 @@ void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns,
613602
}
614603

615604
} // namespace tptr
616-
} // namespace mlir
605+
} // namespace mlir

lib/Conversion/TPtrToLLVM/TPtrToLLVMPass.cpp

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
#include "mlir/Transforms/Passes.h"
1313
#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h"
1414
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
15-
#include "triton/Dialect/Triton/IR/Types.h"
1615

1716
#define DEBUG_TYPE "tptr-to-llvm"
1817
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -45,11 +44,6 @@ struct TptrToLLVMTypeConverter : TypeConverter {
4544

4645
addConversion([&](MemRefType type) -> std::optional<Type> {
4746
auto elementType = type.getElementType();
48-
49-
if (!isa<ptr::PtrType>(elementType)) {
50-
return std::nullopt;
51-
}
52-
5347
auto ctx = type.getContext();
5448
auto rank = type.getShape().size();
5549
auto i64Ty = IntegerType::get(ctx, 64);
@@ -99,7 +93,7 @@ class TPtrToLLVMPass : public tptr::impl::TPtrToLLVMBase<TPtrToLLVMPass> {
9993
}
10094
for (auto dest : {op.getTrueDest(), op.getFalseDest()}) {
10195
for (auto arg : dest->getArguments()) {
102-
if (isa<triton::PointerType, ptr::PtrType>(arg.getType())) {
96+
if (isa<ptr::PtrType>(arg.getType())) {
10397
LDBG("CondBranchOp marked illegal due to block arg type: "
10498
<< arg.getType());
10599
return false;
@@ -111,15 +105,15 @@ class TPtrToLLVMPass : public tptr::impl::TPtrToLLVMBase<TPtrToLLVMPass> {
111105

112106
target.addDynamicallyLegalOp<cf::BranchOp>([&](cf::BranchOp op) {
113107
for (auto operand : op.getOperands()) {
114-
if (isa<triton::PointerType, ptr::PtrType>(operand.getType())) {
108+
if (isa<ptr::PtrType>(operand.getType())) {
115109
LDBG("BranchOp marked illegal due to operand type: "
116110
<< operand.getType());
117111
return false;
118112
}
119113
}
120114

121115
for (auto arg : op.getDest()->getArguments()) {
122-
if (isa<triton::PointerType, ptr::PtrType>(arg.getType())) {
116+
if (isa<ptr::PtrType>(arg.getType())) {
123117
LDBG("BranchOp marked illegal due to block arg type: "
124118
<< arg.getType());
125119
return false;
@@ -131,12 +125,12 @@ class TPtrToLLVMPass : public tptr::impl::TPtrToLLVMBase<TPtrToLLVMPass> {
131125
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
132126
[&](UnrealizedConversionCastOp op) {
133127
for (auto type : op.getResultTypes()) {
134-
if (isa<triton::PointerType, ptr::PtrType>(type)) {
128+
if (isa<ptr::PtrType>(type)) {
135129
return false;
136130
}
137131
}
138132
for (auto operand : op.getOperands()) {
139-
if (isa<triton::PointerType, ptr::PtrType>(operand.getType())) {
133+
if (isa<ptr::PtrType>(operand.getType())) {
140134
return false;
141135
}
142136
}

0 commit comments

Comments
 (0)