Skip to content

Commit 934f936

Browse files
committed
review changes
1 parent 7ab6260 commit 934f936

File tree

7 files changed

+103
-200
lines changed

7 files changed

+103
-200
lines changed

include/triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns,
1616

1717
std::unique_ptr<OperationPass<ModuleOp>> createTPtrToLLVMPass();
1818

19-
static bool isOneToOneCast(UnrealizedConversionCastOp op) {
20-
return (op.getInputs().size() == 1 && op->getNumResults() == 1);
21-
}
22-
2319
} // namespace tptr
2420
} // namespace mlir
2521

include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,6 @@ def TPTR_TypeOffsetOp : TPTR_Op<"type_offset", [ConstantLike, Pure]> {
116116
attr-dict $baseType custom<IntType>(type($result))
117117
}];
118118
let hasFolder = 1;
119-
let extraClassDeclaration = [{
120-
/// Returns the type offset according to `layout`. If `layout` is `nullopt`
121-
/// the nearest layout the op will be used for the computation.
122-
llvm::TypeSize getTypeSize(std::optional<DataLayout> layout = std::nullopt);
123-
}];
124119
}
125120

126121
def TPTR_FromMemrefOp : TPTR_Op<"from_memref", [Pure]> {

lib/Conversion/TPtrToLLVM/TPtrToLLVM.cpp

Lines changed: 53 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ namespace tptr {
2020
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
2121
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
2222

23-
// Build standard MemRef LLVM struct type
23+
static bool isOneToOneCast(UnrealizedConversionCastOp op) {
24+
return (op.getInputs().size() == 1 && op->getNumResults() == 1);
25+
}
26+
2427
Type convertMemRefType(MemRefType type) {
2528
auto ctx = type.getContext();
2629
auto rank = type.getShape().size();
@@ -41,31 +44,11 @@ Type convertMemRefType(MemRefType type) {
4144
}
4245

4346
// PtrAddOp -> llvm.getelementptr conversion
44-
class PtrAddConverter : public OpConversionPattern<tptr::PtrAddOp> {
47+
struct PtrAddConverter : OpConversionPattern<tptr::PtrAddOp> {
4548
using OpConversionPattern<tptr::PtrAddOp>::OpConversionPattern;
4649

47-
Type convertPtrPointerType(ptr::PtrType type,
48-
Type elemTy = nullptr) const {
50+
Type convertPtrPointerType(ptr::PtrType type) const {
4951
auto ctx = type.getContext();
50-
auto pointeeType = elemTy ? elemTy : type.getElementType();
51-
52-
if (isa<RankedTensorType>(pointeeType)) {
53-
// struct {
54-
// ptr base_ptr;
55-
// array<rank x i64> offsets;
56-
// array<rank x i64> shape;
57-
// array<rank x i64> strides;
58-
// }
59-
auto tensorTy = cast<RankedTensorType>(pointeeType);
60-
auto rank = tensorTy.getShape().size();
61-
auto i64Ty = IntegerType::get(ctx, 64);
62-
SmallVector<Type, 4> types{LLVM::LLVMPointerType::get(ctx),
63-
LLVM::LLVMArrayType::get(ctx, i64Ty, rank),
64-
LLVM::LLVMArrayType::get(ctx, i64Ty, rank),
65-
LLVM::LLVMArrayType::get(ctx, i64Ty, rank)};
66-
return LLVM::LLVMStructType::getLiteral(ctx, types);
67-
}
68-
6952
return LLVM::LLVMPointerType::get(ctx);
7053
}
7154

@@ -96,15 +79,12 @@ class PtrAddConverter : public OpConversionPattern<tptr::PtrAddOp> {
9679
elemTy = rewriter.getIntegerType(8); // default to i8
9780
}
9881

99-
Type resTy = convertPtrPointerType(ptrTy, elemTy);
82+
Type resTy = convertPtrPointerType(ptrTy);
10083

101-
// Critical fix: extract element index from byte offset
102-
// offset = count * element_size, we need count
10384
Value elementIndex;
10485
if (auto mulOp = adaptor.getOffset().getDefiningOp<LLVM::MulOp>()) {
10586
elementIndex = mulOp.getLhs();
10687
} else {
107-
// Warning: cannot recognize offset pattern, using raw offset
10888
LDBG("Warning: ptradd offset is not MulOp pattern, using raw offset");
10989
elementIndex = adaptor.getOffset();
11090
}
@@ -118,7 +98,7 @@ class PtrAddConverter : public OpConversionPattern<tptr::PtrAddOp> {
11898
};
11999

120100
// ToMemrefOp -> build LLVM memref struct
121-
class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
101+
struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
122102
using OpConversionPattern<tptr::ToMemrefOp>::OpConversionPattern;
123103

124104
LogicalResult
@@ -151,7 +131,6 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
151131
auto shape = memrefType.getShape();
152132
auto rank = shape.size();
153133

154-
// Build memref struct
155134
Value result = rewriter.create<LLVM::UndefOp>(loc, targetType);
156135
result =
157136
rewriter.create<LLVM::InsertValueOp>(loc, result, input, 0); // base_ptr
@@ -162,15 +141,13 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
162141
loc, i64Ty, rewriter.getIntegerAttr(i64Ty, 0));
163142
result = rewriter.create<LLVM::InsertValueOp>(loc, result, zeroOffset, 2);
164143

165-
// Calculate row-major layout strides
166144
SmallVector<int64_t> strides(rank, 1);
167145
for (int i = rank - 2; i >= 0; --i) {
168146
if (shape[i + 1] != ShapedType::kDynamic) {
169147
strides[i] = strides[i + 1] * shape[i + 1];
170148
}
171149
}
172150

173-
// Set sizes and strides
174151
for (auto [i, size] : llvm::enumerate(shape)) {
175152
Value sizeVal = rewriter.create<LLVM::ConstantOp>(
176153
loc, i64Ty, rewriter.getIntegerAttr(i64Ty, size));
@@ -191,7 +168,7 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
191168
};
192169

193170
// FromMemrefOp -> llvm.extractvalue
194-
class FromMemrefConverter : public OpConversionPattern<tptr::FromMemrefOp> {
171+
struct FromMemrefConverter : OpConversionPattern<tptr::FromMemrefOp> {
195172
using OpConversionPattern<tptr::FromMemrefOp>::OpConversionPattern;
196173

197174
LogicalResult
@@ -221,8 +198,8 @@ class FromMemrefConverter : public OpConversionPattern<tptr::FromMemrefOp> {
221198
};
222199

223200
// Clean up unused UnrealizedConversionCast
224-
class UnrealizedCastConverter
225-
: public OpConversionPattern<UnrealizedConversionCastOp> {
201+
struct UnrealizedCastConverter
202+
: OpConversionPattern<UnrealizedConversionCastOp> {
226203
using OpConversionPattern<UnrealizedConversionCastOp>::OpConversionPattern;
227204

228205
LogicalResult
@@ -242,7 +219,6 @@ class UnrealizedCastConverter
242219
return success();
243220
}
244221

245-
// Reject unsafe conversions
246222
if (isa<ptr::PtrType>(outputType) ||
247223
(isa<LLVM::LLVMPointerType>(inputType) &&
248224
isa<ptr::PtrType>(outputType)) ||
@@ -251,7 +227,6 @@ class UnrealizedCastConverter
251227
return rewriter.notifyMatchFailure(op, "unsafe pointer conversion");
252228
}
253229

254-
// Allowed safe conversions
255230
if ((isa<LLVM::LLVMStructType>(inputType) && isa<MemRefType>(outputType)) ||
256231
(isa<MemRefType>(inputType) && isa<LLVM::LLVMStructType>(outputType))) {
257232
LDBG("matchAndRewrite: replace with input: " << op << " -> " << input);
@@ -290,7 +265,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
290265
}
291266

292267
// Conditional branch conversion
293-
class ConvertControlFlowOp : public OpConversionPattern<cf::CondBranchOp> {
268+
struct ConvertControlFlowOp : OpConversionPattern<cf::CondBranchOp> {
294269
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
295270

296271
LogicalResult
@@ -316,7 +291,7 @@ class ConvertControlFlowOp : public OpConversionPattern<cf::CondBranchOp> {
316291
};
317292

318293
// Unconditional branch conversion
319-
class ConvertBranchOp : public OpConversionPattern<cf::BranchOp> {
294+
struct ConvertBranchOp : OpConversionPattern<cf::BranchOp> {
320295
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
321296

322297
LogicalResult
@@ -338,7 +313,7 @@ class ConvertBranchOp : public OpConversionPattern<cf::BranchOp> {
338313

339314
// MemRef allocation with pointer element types -> LLVM malloc + struct
340315
// construction
341-
class MemRefAllocConverter : public OpConversionPattern<memref::AllocOp> {
316+
struct MemRefAllocConverter : OpConversionPattern<memref::AllocOp> {
342317
using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
343318

344319
LogicalResult
@@ -421,7 +396,7 @@ class MemRefAllocConverter : public OpConversionPattern<memref::AllocOp> {
421396
};
422397

423398
// MemRef store with pointer element types -> LLVM GEP + store
424-
class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
399+
struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
425400
using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
426401

427402
LogicalResult
@@ -456,7 +431,9 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
456431
// Convert index to i64 if needed
457432
if (index.getType() != i64Ty) {
458433
if (isa<IndexType>(index.getType())) {
459-
index = rewriter.create<UnrealizedConversionCastOp>(loc, i64Ty, index).getResult(0);
434+
index =
435+
rewriter.create<UnrealizedConversionCastOp>(loc, i64Ty, index)
436+
.getResult(0);
460437
}
461438
}
462439
linearIndex = index;
@@ -470,21 +447,25 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
470447
Value convertedIndex = index;
471448
if (index.getType() != i64Ty) {
472449
if (isa<IndexType>(index.getType())) {
473-
convertedIndex = rewriter.create<UnrealizedConversionCastOp>(loc, i64Ty, index).getResult(0);
450+
convertedIndex =
451+
rewriter.create<UnrealizedConversionCastOp>(loc, i64Ty, index)
452+
.getResult(0);
474453
}
475454
}
476455

477456
Value stride = rewriter.create<LLVM::ExtractValueOp>(
478457
loc, i64Ty, memrefDescriptor,
479458
rewriter.getDenseI64ArrayAttr({4, static_cast<int64_t>(i)}));
480-
Value contribution = rewriter.create<LLVM::MulOp>(loc, convertedIndex, stride);
481-
linearIndex = rewriter.create<LLVM::AddOp>(loc, linearIndex, contribution);
459+
Value contribution =
460+
rewriter.create<LLVM::MulOp>(loc, convertedIndex, stride);
461+
linearIndex =
462+
rewriter.create<LLVM::AddOp>(loc, linearIndex, contribution);
482463
}
483464
}
484465

485466
// GEP to get the address of the element
486-
Value elementPtr = rewriter.create<LLVM::GEPOp>(
487-
loc, ptrTy, ptrTy, basePtr, linearIndex);
467+
Value elementPtr =
468+
rewriter.create<LLVM::GEPOp>(loc, ptrTy, ptrTy, basePtr, linearIndex);
488469

489470
// Store the value
490471
rewriter.create<LLVM::StoreOp>(loc, adaptor.getValue(), elementPtr);
@@ -499,7 +480,7 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
499480
};
500481

501482
// MemRef load with pointer element types -> LLVM GEP + load
502-
class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
483+
struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
503484
using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
504485

505486
LogicalResult
@@ -540,7 +521,9 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
540521
// Convert index to i64 if needed
541522
if (index.getType() != i64Ty) {
542523
if (isa<IndexType>(index.getType())) {
543-
index = rewriter.create<UnrealizedConversionCastOp>(loc, i64Ty, index).getResult(0);
524+
index =
525+
rewriter.create<UnrealizedConversionCastOp>(loc, i64Ty, index)
526+
.getResult(0);
544527
}
545528
}
546529
linearIndex = index;
@@ -554,24 +537,29 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
554537
Value convertedIndex = index;
555538
if (index.getType() != i64Ty) {
556539
if (isa<IndexType>(index.getType())) {
557-
convertedIndex = rewriter.create<UnrealizedConversionCastOp>(loc, i64Ty, index).getResult(0);
540+
convertedIndex =
541+
rewriter.create<UnrealizedConversionCastOp>(loc, i64Ty, index)
542+
.getResult(0);
558543
}
559544
}
560545

561546
Value stride = rewriter.create<LLVM::ExtractValueOp>(
562547
loc, i64Ty, memrefDescriptor,
563548
rewriter.getDenseI64ArrayAttr({4, static_cast<int64_t>(i)}));
564-
Value contribution = rewriter.create<LLVM::MulOp>(loc, convertedIndex, stride);
565-
linearIndex = rewriter.create<LLVM::AddOp>(loc, linearIndex, contribution);
549+
Value contribution =
550+
rewriter.create<LLVM::MulOp>(loc, convertedIndex, stride);
551+
linearIndex =
552+
rewriter.create<LLVM::AddOp>(loc, linearIndex, contribution);
566553
}
567554
}
568555

569556
// GEP to get the address of the element
570-
Value elementPtr = rewriter.create<LLVM::GEPOp>(
571-
loc, ptrTy, ptrTy, basePtr, linearIndex);
557+
Value elementPtr =
558+
rewriter.create<LLVM::GEPOp>(loc, ptrTy, ptrTy, basePtr, linearIndex);
572559

573560
// Load the value
574-
Value loadedValue = rewriter.create<LLVM::LoadOp>(loc, newResultType, elementPtr);
561+
Value loadedValue =
562+
rewriter.create<LLVM::LoadOp>(loc, newResultType, elementPtr);
575563
rewriter.replaceOp(op, loadedValue);
576564

577565
LDBG("matchAndRewrite: memref.load done -> LLVM GEP + load");
@@ -583,15 +571,24 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
583571
};
584572

585573
// TypeOffsetOp -> constant conversion
586-
class TypeOffsetConverter : public OpConversionPattern<tptr::TypeOffsetOp> {
574+
struct TypeOffsetConverter : OpConversionPattern<tptr::TypeOffsetOp> {
587575
using OpConversionPattern<tptr::TypeOffsetOp>::OpConversionPattern;
588576

577+
llvm::TypeSize
578+
getTypeSize(tptr::TypeOffsetOp op,
579+
std::optional<DataLayout> layout = std::nullopt) const {
580+
if (layout)
581+
return layout->getTypeSize(op.getBaseType());
582+
DataLayout dl = DataLayout::closest(op);
583+
return dl.getTypeSize(op.getBaseType());
584+
}
585+
589586
LogicalResult
590587
matchAndRewrite(tptr::TypeOffsetOp op, OpAdaptor adaptor,
591588
ConversionPatternRewriter &rewriter) const override {
592589
LDBG("matchAndRewrite: type_offset " << op);
593590

594-
auto size = op.getTypeSize();
591+
auto size = getTypeSize(op);
595592
auto constOp = rewriter.create<LLVM::ConstantOp>(
596593
op.getLoc(), op.getType(), rewriter.getIntegerAttr(op.getType(), size));
597594

0 commit comments

Comments
 (0)