@@ -20,7 +20,10 @@ namespace tptr {
// Debug-logging helpers.
// DBGS yields llvm::dbgs() after writing a bracketed DEBUG_TYPE tag to it.
// LDBG streams X followed by a newline through DBGS, wrapped in LLVM_DEBUG.
// NOTE(review): the space between each macro name and its '(' looks like a
// rendering artifact of this diff view; a function-like macro definition
// must have no space there -- confirm against the real file.
2020#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
2121#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
2222
23- // Build standard MemRef LLVM struct type
// Returns true when the given unrealized conversion cast has exactly one
// input operand and exactly one result.
23+ static bool isOneToOneCast (UnrealizedConversionCastOp op) {
24+ return (op.getInputs ().size () == 1 && op->getNumResults () == 1 );
25+ }
26+
2427Type convertMemRefType (MemRefType type) {
2528 auto ctx = type.getContext ();
2629 auto rank = type.getShape ().size ();
@@ -41,31 +44,11 @@ Type convertMemRefType(MemRefType type) {
4144}
4245
4346// PtrAddOp -> llvm.getelementptr conversion
44- class PtrAddConverter : public OpConversionPattern <tptr::PtrAddOp> {
47+ struct PtrAddConverter : OpConversionPattern<tptr::PtrAddOp> {
4548 using OpConversionPattern<tptr::PtrAddOp>::OpConversionPattern;
4649
// Maps a ptr::PtrType to the LLVM type used for it after conversion.
// Removed ('-') variant: accepted an optional elemTy override and, when the
// pointee was a RankedTensorType, returned a literal LLVM struct holding a
// pointer plus three rank-sized i64 arrays; otherwise an LLVM pointer type.
// Added ('+') variant: takes only the pointer type and always returns
// LLVM::LLVMPointerType::get(ctx).
47- Type convertPtrPointerType (ptr::PtrType type,
48- Type elemTy = nullptr ) const {
50+ Type convertPtrPointerType (ptr::PtrType type) const {
4951 auto ctx = type.getContext ();
50- auto pointeeType = elemTy ? elemTy : type.getElementType ();
51-
52- if (isa<RankedTensorType>(pointeeType)) {
53- // struct {
54- // ptr base_ptr;
55- // array<rank x i64> offsets;
56- // array<rank x i64> shape;
57- // array<rank x i64> strides;
58- // }
59- auto tensorTy = cast<RankedTensorType>(pointeeType);
60- auto rank = tensorTy.getShape ().size ();
61- auto i64Ty = IntegerType::get (ctx, 64 );
62- SmallVector<Type, 4 > types{LLVM::LLVMPointerType::get (ctx),
63- LLVM::LLVMArrayType::get (ctx, i64Ty, rank),
64- LLVM::LLVMArrayType::get (ctx, i64Ty, rank),
65- LLVM::LLVMArrayType::get (ctx, i64Ty, rank)};
66- return LLVM::LLVMStructType::getLiteral (ctx, types);
67- }
68-
6952 return LLVM::LLVMPointerType::get (ctx);
7053 }
7154
@@ -96,15 +79,12 @@ class PtrAddConverter : public OpConversionPattern<tptr::PtrAddOp> {
9679 elemTy = rewriter.getIntegerType (8 ); // default to i8
9780 }
9881
99- Type resTy = convertPtrPointerType (ptrTy, elemTy );
82+ Type resTy = convertPtrPointerType (ptrTy);
10083
101- // Critical fix: extract element index from byte offset
102- // offset = count * element_size, we need count
10384 Value elementIndex;
10485 if (auto mulOp = adaptor.getOffset ().getDefiningOp <LLVM::MulOp>()) {
10586 elementIndex = mulOp.getLhs ();
10687 } else {
107- // Warning: cannot recognize offset pattern, using raw offset
10888 LDBG (" Warning: ptradd offset is not MulOp pattern, using raw offset" );
10989 elementIndex = adaptor.getOffset ();
11090 }
@@ -118,7 +98,7 @@ class PtrAddConverter : public OpConversionPattern<tptr::PtrAddOp> {
11898};
11999
120100// ToMemrefOp -> build LLVM memref struct
121- class ToMemrefConverter : public OpConversionPattern <tptr::ToMemrefOp> {
101+ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
122102 using OpConversionPattern<tptr::ToMemrefOp>::OpConversionPattern;
123103
124104 LogicalResult
@@ -151,7 +131,6 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
151131 auto shape = memrefType.getShape ();
152132 auto rank = shape.size ();
153133
154- // Build memref struct
155134 Value result = rewriter.create <LLVM::UndefOp>(loc, targetType);
156135 result =
157136 rewriter.create <LLVM::InsertValueOp>(loc, result, input, 0 ); // base_ptr
@@ -162,15 +141,13 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
162141 loc, i64Ty, rewriter.getIntegerAttr (i64Ty, 0 ));
163142 result = rewriter.create <LLVM::InsertValueOp>(loc, result, zeroOffset, 2 );
164143
165- // Calculate row-major layout strides
166144 SmallVector<int64_t > strides (rank, 1 );
167145 for (int i = rank - 2 ; i >= 0 ; --i) {
168146 if (shape[i + 1 ] != ShapedType::kDynamic ) {
169147 strides[i] = strides[i + 1 ] * shape[i + 1 ];
170148 }
171149 }
172150
173- // Set sizes and strides
174151 for (auto [i, size] : llvm::enumerate (shape)) {
175152 Value sizeVal = rewriter.create <LLVM::ConstantOp>(
176153 loc, i64Ty, rewriter.getIntegerAttr (i64Ty, size));
@@ -191,7 +168,7 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
191168};
192169
193170// FromMemrefOp -> llvm.extractvalue
194- class FromMemrefConverter : public OpConversionPattern <tptr::FromMemrefOp> {
171+ struct FromMemrefConverter : OpConversionPattern<tptr::FromMemrefOp> {
195172 using OpConversionPattern<tptr::FromMemrefOp>::OpConversionPattern;
196173
197174 LogicalResult
@@ -221,8 +198,8 @@ class FromMemrefConverter : public OpConversionPattern<tptr::FromMemrefOp> {
221198};
222199
223200// Clean up unused UnrealizedConversionCast
224- class UnrealizedCastConverter
225- : public OpConversionPattern<UnrealizedConversionCastOp> {
201+ struct UnrealizedCastConverter
202+ : OpConversionPattern<UnrealizedConversionCastOp> {
226203 using OpConversionPattern<UnrealizedConversionCastOp>::OpConversionPattern;
227204
228205 LogicalResult
@@ -242,7 +219,6 @@ class UnrealizedCastConverter
242219 return success ();
243220 }
244221
245- // Reject unsafe conversions
246222 if (isa<ptr::PtrType>(outputType) ||
247223 (isa<LLVM::LLVMPointerType>(inputType) &&
248224 isa<ptr::PtrType>(outputType)) ||
@@ -251,7 +227,6 @@ class UnrealizedCastConverter
251227 return rewriter.notifyMatchFailure (op, " unsafe pointer conversion" );
252228 }
253229
254- // Allowed safe conversions
255230 if ((isa<LLVM::LLVMStructType>(inputType) && isa<MemRefType>(outputType)) ||
256231 (isa<MemRefType>(inputType) && isa<LLVM::LLVMStructType>(outputType))) {
257232 LDBG (" matchAndRewrite: replace with input: " << op << " -> " << input);
@@ -290,7 +265,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
290265}
291266
292267// Conditional branch conversion
293- class ConvertControlFlowOp : public OpConversionPattern <cf::CondBranchOp> {
268+ struct ConvertControlFlowOp : OpConversionPattern<cf::CondBranchOp> {
294269 using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
295270
296271 LogicalResult
@@ -316,7 +291,7 @@ class ConvertControlFlowOp : public OpConversionPattern<cf::CondBranchOp> {
316291};
317292
318293// Unconditional branch conversion
319- class ConvertBranchOp : public OpConversionPattern <cf::BranchOp> {
294+ struct ConvertBranchOp : OpConversionPattern<cf::BranchOp> {
320295 using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
321296
322297 LogicalResult
@@ -338,7 +313,7 @@ class ConvertBranchOp : public OpConversionPattern<cf::BranchOp> {
338313
339314// MemRef allocation with pointer element types -> LLVM malloc + struct
340315// construction
341- class MemRefAllocConverter : public OpConversionPattern <memref::AllocOp> {
316+ struct MemRefAllocConverter : OpConversionPattern<memref::AllocOp> {
342317 using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
343318
344319 LogicalResult
@@ -421,7 +396,7 @@ class MemRefAllocConverter : public OpConversionPattern<memref::AllocOp> {
421396};
422397
423398// MemRef store with pointer element types -> LLVM GEP + store
424- class MemRefStoreConverter : public OpConversionPattern <memref::StoreOp> {
399+ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
425400 using OpConversionPattern<memref::StoreOp>::OpConversionPattern;
426401
427402 LogicalResult
@@ -456,7 +431,9 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
456431 // Convert index to i64 if needed
457432 if (index.getType () != i64Ty) {
458433 if (isa<IndexType>(index.getType ())) {
459- index = rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index).getResult (0 );
434+ index =
435+ rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index)
436+ .getResult (0 );
460437 }
461438 }
462439 linearIndex = index;
@@ -470,21 +447,25 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
470447 Value convertedIndex = index;
471448 if (index.getType () != i64Ty) {
472449 if (isa<IndexType>(index.getType ())) {
473- convertedIndex = rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index).getResult (0 );
450+ convertedIndex =
451+ rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index)
452+ .getResult (0 );
474453 }
475454 }
476455
477456 Value stride = rewriter.create <LLVM::ExtractValueOp>(
478457 loc, i64Ty, memrefDescriptor,
479458 rewriter.getDenseI64ArrayAttr ({4 , static_cast <int64_t >(i)}));
480- Value contribution = rewriter.create <LLVM::MulOp>(loc, convertedIndex, stride);
481- linearIndex = rewriter.create <LLVM::AddOp>(loc, linearIndex, contribution);
459+ Value contribution =
460+ rewriter.create <LLVM::MulOp>(loc, convertedIndex, stride);
461+ linearIndex =
462+ rewriter.create <LLVM::AddOp>(loc, linearIndex, contribution);
482463 }
483464 }
484465
485466 // GEP to get the address of the element
486- Value elementPtr = rewriter. create <LLVM::GEPOp>(
487- loc, ptrTy, ptrTy, basePtr, linearIndex);
467+ Value elementPtr =
468+ rewriter. create <LLVM::GEPOp>( loc, ptrTy, ptrTy, basePtr, linearIndex);
488469
489470 // Store the value
490471 rewriter.create <LLVM::StoreOp>(loc, adaptor.getValue (), elementPtr);
@@ -499,7 +480,7 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
499480};
500481
501482// MemRef load with pointer element types -> LLVM GEP + load
502- class MemRefLoadConverter : public OpConversionPattern <memref::LoadOp> {
483+ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
503484 using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
504485
505486 LogicalResult
@@ -540,7 +521,9 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
540521 // Convert index to i64 if needed
541522 if (index.getType () != i64Ty) {
542523 if (isa<IndexType>(index.getType ())) {
543- index = rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index).getResult (0 );
524+ index =
525+ rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index)
526+ .getResult (0 );
544527 }
545528 }
546529 linearIndex = index;
@@ -554,24 +537,29 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
554537 Value convertedIndex = index;
555538 if (index.getType () != i64Ty) {
556539 if (isa<IndexType>(index.getType ())) {
557- convertedIndex = rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index).getResult (0 );
540+ convertedIndex =
541+ rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index)
542+ .getResult (0 );
558543 }
559544 }
560545
561546 Value stride = rewriter.create <LLVM::ExtractValueOp>(
562547 loc, i64Ty, memrefDescriptor,
563548 rewriter.getDenseI64ArrayAttr ({4 , static_cast <int64_t >(i)}));
564- Value contribution = rewriter.create <LLVM::MulOp>(loc, convertedIndex, stride);
565- linearIndex = rewriter.create <LLVM::AddOp>(loc, linearIndex, contribution);
549+ Value contribution =
550+ rewriter.create <LLVM::MulOp>(loc, convertedIndex, stride);
551+ linearIndex =
552+ rewriter.create <LLVM::AddOp>(loc, linearIndex, contribution);
566553 }
567554 }
568555
569556 // GEP to get the address of the element
570- Value elementPtr = rewriter. create <LLVM::GEPOp>(
571- loc, ptrTy, ptrTy, basePtr, linearIndex);
557+ Value elementPtr =
558+ rewriter. create <LLVM::GEPOp>( loc, ptrTy, ptrTy, basePtr, linearIndex);
572559
573560 // Load the value
574- Value loadedValue = rewriter.create <LLVM::LoadOp>(loc, newResultType, elementPtr);
561+ Value loadedValue =
562+ rewriter.create <LLVM::LoadOp>(loc, newResultType, elementPtr);
575563 rewriter.replaceOp (op, loadedValue);
576564
577565 LDBG (" matchAndRewrite: memref.load done -> LLVM GEP + load" );
@@ -583,15 +571,24 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
583571};
584572
585573// TypeOffsetOp -> constant conversion
586- class TypeOffsetConverter : public OpConversionPattern <tptr::TypeOffsetOp> {
574+ struct TypeOffsetConverter : OpConversionPattern<tptr::TypeOffsetOp> {
587575 using OpConversionPattern<tptr::TypeOffsetOp>::OpConversionPattern;
588576
// Returns the data-layout size of op's base type.
// If a DataLayout is supplied it is used directly; otherwise the layout is
// obtained from DataLayout::closest(op).
// NOTE(review): the unit of the returned size is presumably bytes (as opposed
// to a bit-size query) -- confirm against the DataLayout API.
577+ llvm::TypeSize
578+ getTypeSize (tptr::TypeOffsetOp op,
579+ std::optional<DataLayout> layout = std::nullopt ) const {
580+ if (layout)
581+ return layout->getTypeSize (op.getBaseType ());
// No explicit layout was passed: fall back to the layout closest to the op.
582+ DataLayout dl = DataLayout::closest (op);
583+ return dl.getTypeSize (op.getBaseType ());
584+ }
585+
589586 LogicalResult
590587 matchAndRewrite (tptr::TypeOffsetOp op, OpAdaptor adaptor,
591588 ConversionPatternRewriter &rewriter) const override {
592589 LDBG (" matchAndRewrite: type_offset " << op);
593590
594- auto size = op. getTypeSize ();
591+ auto size = getTypeSize (op );
595592 auto constOp = rewriter.create <LLVM::ConstantOp>(
596593 op.getLoc (), op.getType (), rewriter.getIntegerAttr (op.getType (), size));
597594
0 commit comments