@@ -20,7 +20,10 @@ namespace tptr {
2020#define  DBGS () (llvm::dbgs() << " ["   DEBUG_TYPE " ]: "  )
2121#define  LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n "  )
2222
23- //  Build standard MemRef LLVM struct type
23+ static  bool  isOneToOneCast (UnrealizedConversionCastOp op) {
24+   return  (op.getInputs ().size () == 1  && op->getNumResults () == 1 );
25+ }
26+ 
2427Type convertMemRefType (MemRefType type) {
2528  auto  ctx = type.getContext ();
2629  auto  rank = type.getShape ().size ();
@@ -41,31 +44,11 @@ Type convertMemRefType(MemRefType type) {
4144}
4245
4346//  PtrAddOp -> llvm.getelementptr conversion
44- class  PtrAddConverter  :  public  OpConversionPattern <tptr::PtrAddOp> {
47+ struct  PtrAddConverter  : OpConversionPattern<tptr::PtrAddOp> {
4548  using  OpConversionPattern<tptr::PtrAddOp>::OpConversionPattern;
4649
47-   Type convertPtrPointerType (ptr::PtrType type,
48-                                 Type elemTy = nullptr ) const  {
50+   Type convertPtrPointerType (ptr::PtrType type) const  {
4951    auto  ctx = type.getContext ();
50-     auto  pointeeType = elemTy ? elemTy : type.getElementType ();
51- 
52-     if  (isa<RankedTensorType>(pointeeType)) {
53-       //  struct {
54-       //    ptr base_ptr;
55-       //    array<rank x i64> offsets;
56-       //    array<rank x i64> shape;
57-       //    array<rank x i64> strides;
58-       //  }
59-       auto  tensorTy = cast<RankedTensorType>(pointeeType);
60-       auto  rank = tensorTy.getShape ().size ();
61-       auto  i64Ty = IntegerType::get (ctx, 64 );
62-       SmallVector<Type, 4 > types{LLVM::LLVMPointerType::get (ctx),
63-                                  LLVM::LLVMArrayType::get (ctx, i64Ty, rank),
64-                                  LLVM::LLVMArrayType::get (ctx, i64Ty, rank),
65-                                  LLVM::LLVMArrayType::get (ctx, i64Ty, rank)};
66-       return  LLVM::LLVMStructType::getLiteral (ctx, types);
67-     }
68- 
6952    return  LLVM::LLVMPointerType::get (ctx);
7053  }
7154
@@ -96,15 +79,12 @@ class PtrAddConverter : public OpConversionPattern<tptr::PtrAddOp> {
9679      elemTy = rewriter.getIntegerType (8 ); //  default to i8
9780    }
9881
99-     Type resTy = convertPtrPointerType (ptrTy, elemTy );
82+     Type resTy = convertPtrPointerType (ptrTy);
10083
101-     //  Critical fix: extract element index from byte offset
102-     //  offset = count * element_size, we need count
10384    Value elementIndex;
10485    if  (auto  mulOp = adaptor.getOffset ().getDefiningOp <LLVM::MulOp>()) {
10586      elementIndex = mulOp.getLhs ();
10687    } else  {
107-       //  Warning: cannot recognize offset pattern, using raw offset
10888      LDBG (" Warning: ptradd offset is not MulOp pattern, using raw offset"  );
10989      elementIndex = adaptor.getOffset ();
11090    }
@@ -118,7 +98,7 @@ class PtrAddConverter : public OpConversionPattern<tptr::PtrAddOp> {
11898};
11999
120100//  ToMemrefOp -> build LLVM memref struct
121- class  ToMemrefConverter  :  public  OpConversionPattern <tptr::ToMemrefOp> {
101+ struct  ToMemrefConverter  : OpConversionPattern<tptr::ToMemrefOp> {
122102  using  OpConversionPattern<tptr::ToMemrefOp>::OpConversionPattern;
123103
124104  LogicalResult
@@ -151,7 +131,6 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
151131    auto  shape = memrefType.getShape ();
152132    auto  rank = shape.size ();
153133
154-     //  Build memref struct
155134    Value result = rewriter.create <LLVM::UndefOp>(loc, targetType);
156135    result =
157136        rewriter.create <LLVM::InsertValueOp>(loc, result, input, 0 ); //  base_ptr
@@ -162,15 +141,13 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
162141        loc, i64Ty, rewriter.getIntegerAttr (i64Ty, 0 ));
163142    result = rewriter.create <LLVM::InsertValueOp>(loc, result, zeroOffset, 2 );
164143
165-     //  Calculate row-major layout strides
166144    SmallVector<int64_t > strides (rank, 1 );
167145    for  (int  i = rank - 2 ; i >= 0 ; --i) {
168146      if  (shape[i + 1 ] != ShapedType::kDynamic ) {
169147        strides[i] = strides[i + 1 ] * shape[i + 1 ];
170148      }
171149    }
172150
173-     //  Set sizes and strides
174151    for  (auto  [i, size] : llvm::enumerate (shape)) {
175152      Value sizeVal = rewriter.create <LLVM::ConstantOp>(
176153          loc, i64Ty, rewriter.getIntegerAttr (i64Ty, size));
@@ -191,7 +168,7 @@ class ToMemrefConverter : public OpConversionPattern<tptr::ToMemrefOp> {
191168};
192169
193170//  FromMemrefOp -> llvm.extractvalue
194- class  FromMemrefConverter  :  public  OpConversionPattern <tptr::FromMemrefOp> {
171+ struct  FromMemrefConverter  : OpConversionPattern<tptr::FromMemrefOp> {
195172  using  OpConversionPattern<tptr::FromMemrefOp>::OpConversionPattern;
196173
197174  LogicalResult
@@ -221,8 +198,8 @@ class FromMemrefConverter : public OpConversionPattern<tptr::FromMemrefOp> {
221198};
222199
223200//  Clean up unused UnrealizedConversionCast
224- class  UnrealizedCastConverter 
225-     : public  OpConversionPattern<UnrealizedConversionCastOp> {
201+ struct  UnrealizedCastConverter 
202+     : OpConversionPattern<UnrealizedConversionCastOp> {
226203  using  OpConversionPattern<UnrealizedConversionCastOp>::OpConversionPattern;
227204
228205  LogicalResult
@@ -242,7 +219,6 @@ class UnrealizedCastConverter
242219      return  success ();
243220    }
244221
245-     //  Reject unsafe conversions
246222    if  (isa<ptr::PtrType>(outputType) ||
247223        (isa<LLVM::LLVMPointerType>(inputType) &&
248224         isa<ptr::PtrType>(outputType)) ||
@@ -251,7 +227,6 @@ class UnrealizedCastConverter
251227      return  rewriter.notifyMatchFailure (op, " unsafe pointer conversion"  );
252228    }
253229
254-     //  Allowed safe conversions
255230    if  ((isa<LLVM::LLVMStructType>(inputType) && isa<MemRefType>(outputType)) ||
256231        (isa<MemRefType>(inputType) && isa<LLVM::LLVMStructType>(outputType))) {
257232      LDBG (" matchAndRewrite: replace with input: "   << op << "  -> "   << input);
@@ -290,7 +265,7 @@ static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
290265}
291266
292267//  Conditional branch conversion
293- class  ConvertControlFlowOp  :  public  OpConversionPattern <cf::CondBranchOp> {
268+ struct  ConvertControlFlowOp  : OpConversionPattern<cf::CondBranchOp> {
294269  using  OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
295270
296271  LogicalResult
@@ -316,7 +291,7 @@ class ConvertControlFlowOp : public OpConversionPattern<cf::CondBranchOp> {
316291};
317292
318293//  Unconditional branch conversion
319- class  ConvertBranchOp  :  public  OpConversionPattern <cf::BranchOp> {
294+ struct  ConvertBranchOp  : OpConversionPattern<cf::BranchOp> {
320295  using  OpConversionPattern<cf::BranchOp>::OpConversionPattern;
321296
322297  LogicalResult
@@ -338,7 +313,7 @@ class ConvertBranchOp : public OpConversionPattern<cf::BranchOp> {
338313
339314//  MemRef allocation with pointer element types -> LLVM malloc + struct
340315//  construction
341- class  MemRefAllocConverter  :  public  OpConversionPattern <memref::AllocOp> {
316+ struct  MemRefAllocConverter  : OpConversionPattern<memref::AllocOp> {
342317  using  OpConversionPattern<memref::AllocOp>::OpConversionPattern;
343318
344319  LogicalResult
@@ -421,7 +396,7 @@ class MemRefAllocConverter : public OpConversionPattern<memref::AllocOp> {
421396};
422397
423398//  MemRef store with pointer element types -> LLVM GEP + store
424- class  MemRefStoreConverter  :  public  OpConversionPattern <memref::StoreOp> {
399+ struct  MemRefStoreConverter  : OpConversionPattern<memref::StoreOp> {
425400  using  OpConversionPattern<memref::StoreOp>::OpConversionPattern;
426401
427402  LogicalResult
@@ -456,7 +431,9 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
456431        //  Convert index to i64 if needed
457432        if  (index.getType () != i64Ty) {
458433          if  (isa<IndexType>(index.getType ())) {
459-             index = rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index).getResult (0 );
434+             index =
435+                 rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index)
436+                     .getResult (0 );
460437          }
461438        }
462439        linearIndex = index;
@@ -470,21 +447,25 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
470447          Value convertedIndex = index;
471448          if  (index.getType () != i64Ty) {
472449            if  (isa<IndexType>(index.getType ())) {
473-               convertedIndex = rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index).getResult (0 );
450+               convertedIndex =
451+                   rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index)
452+                       .getResult (0 );
474453            }
475454          }
476455
477456          Value stride = rewriter.create <LLVM::ExtractValueOp>(
478457              loc, i64Ty, memrefDescriptor,
479458              rewriter.getDenseI64ArrayAttr ({4 , static_cast <int64_t >(i)}));
480-           Value contribution = rewriter.create <LLVM::MulOp>(loc, convertedIndex, stride);
481-           linearIndex = rewriter.create <LLVM::AddOp>(loc, linearIndex, contribution);
459+           Value contribution =
460+               rewriter.create <LLVM::MulOp>(loc, convertedIndex, stride);
461+           linearIndex =
462+               rewriter.create <LLVM::AddOp>(loc, linearIndex, contribution);
482463        }
483464      }
484465
485466      //  GEP to get the address of the element
486-       Value elementPtr = rewriter. create <LLVM::GEPOp>( 
487-           loc, ptrTy, ptrTy, basePtr, linearIndex);
467+       Value elementPtr =
468+           rewriter. create <LLVM::GEPOp>( loc, ptrTy, ptrTy, basePtr, linearIndex);
488469
489470      //  Store the value
490471      rewriter.create <LLVM::StoreOp>(loc, adaptor.getValue (), elementPtr);
@@ -499,7 +480,7 @@ class MemRefStoreConverter : public OpConversionPattern<memref::StoreOp> {
499480};
500481
501482//  MemRef load with pointer element types -> LLVM GEP + load
502- class  MemRefLoadConverter  :  public  OpConversionPattern <memref::LoadOp> {
483+ struct  MemRefLoadConverter  : OpConversionPattern<memref::LoadOp> {
503484  using  OpConversionPattern<memref::LoadOp>::OpConversionPattern;
504485
505486  LogicalResult
@@ -540,7 +521,9 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
540521        //  Convert index to i64 if needed
541522        if  (index.getType () != i64Ty) {
542523          if  (isa<IndexType>(index.getType ())) {
543-             index = rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index).getResult (0 );
524+             index =
525+                 rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index)
526+                     .getResult (0 );
544527          }
545528        }
546529        linearIndex = index;
@@ -554,24 +537,29 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
554537          Value convertedIndex = index;
555538          if  (index.getType () != i64Ty) {
556539            if  (isa<IndexType>(index.getType ())) {
557-               convertedIndex = rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index).getResult (0 );
540+               convertedIndex =
541+                   rewriter.create <UnrealizedConversionCastOp>(loc, i64Ty, index)
542+                       .getResult (0 );
558543            }
559544          }
560545
561546          Value stride = rewriter.create <LLVM::ExtractValueOp>(
562547              loc, i64Ty, memrefDescriptor,
563548              rewriter.getDenseI64ArrayAttr ({4 , static_cast <int64_t >(i)}));
564-           Value contribution = rewriter.create <LLVM::MulOp>(loc, convertedIndex, stride);
565-           linearIndex = rewriter.create <LLVM::AddOp>(loc, linearIndex, contribution);
549+           Value contribution =
550+               rewriter.create <LLVM::MulOp>(loc, convertedIndex, stride);
551+           linearIndex =
552+               rewriter.create <LLVM::AddOp>(loc, linearIndex, contribution);
566553        }
567554      }
568555
569556      //  GEP to get the address of the element
570-       Value elementPtr = rewriter. create <LLVM::GEPOp>( 
571-           loc, ptrTy, ptrTy, basePtr, linearIndex);
557+       Value elementPtr =
558+           rewriter. create <LLVM::GEPOp>( loc, ptrTy, ptrTy, basePtr, linearIndex);
572559
573560      //  Load the value
574-       Value loadedValue = rewriter.create <LLVM::LoadOp>(loc, newResultType, elementPtr);
561+       Value loadedValue =
562+           rewriter.create <LLVM::LoadOp>(loc, newResultType, elementPtr);
575563      rewriter.replaceOp (op, loadedValue);
576564
577565      LDBG (" matchAndRewrite: memref.load done -> LLVM GEP + load"  );
@@ -583,15 +571,24 @@ class MemRefLoadConverter : public OpConversionPattern<memref::LoadOp> {
583571};
584572
585573//  TypeOffsetOp -> constant conversion
586- class  TypeOffsetConverter  :  public  OpConversionPattern <tptr::TypeOffsetOp> {
574+ struct  TypeOffsetConverter  : OpConversionPattern<tptr::TypeOffsetOp> {
587575  using  OpConversionPattern<tptr::TypeOffsetOp>::OpConversionPattern;
588576
577+   llvm::TypeSize
578+   getTypeSize (tptr::TypeOffsetOp op,
579+               std::optional<DataLayout> layout = std::nullopt ) const  {
580+     if  (layout)
581+       return  layout->getTypeSize (op.getBaseType ());
582+     DataLayout dl = DataLayout::closest (op);
583+     return  dl.getTypeSize (op.getBaseType ());
584+   }
585+ 
589586  LogicalResult
590587  matchAndRewrite (tptr::TypeOffsetOp op, OpAdaptor adaptor,
591588                  ConversionPatternRewriter &rewriter) const  override  {
592589    LDBG (" matchAndRewrite: type_offset "   << op);
593590
594-     auto  size = op. getTypeSize ();
591+     auto  size = getTypeSize (op );
595592    auto  constOp = rewriter.create <LLVM::ConstantOp>(
596593        op.getLoc (), op.getType (), rewriter.getIntegerAttr (op.getType (), size));
597594
0 commit comments