55#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
66#include " mlir/Dialect/MemRef/IR/MemRef.h"
77#include " mlir/Dialect/Ptr/IR/PtrTypes.h"
8+ #include " mlir/IR/BuiltinAttributes.h"
9+ #include " mlir/IR/BuiltinOps.h"
810#include " mlir/IR/BuiltinTypes.h"
911#include " mlir/IR/PatternMatch.h"
1012#include " mlir/IR/Value.h"
@@ -26,7 +28,6 @@ static bool isOneToOneCast(UnrealizedConversionCastOp op) {
2628 return (op.getInputs ().size () == 1 && op->getNumResults () == 1 );
2729}
2830
29-
3031// PtrAddOp -> llvm.getelementptr conversion
3132struct PtrAddConverter : OpConversionPattern<tptr::PtrAddOp> {
3233 using OpConversionPattern<tptr::PtrAddOp>::OpConversionPattern;
@@ -109,8 +110,10 @@ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
109110 }
110111 }
111112
112- Type targetType = getTypeConverter ()->convertType (cast<MemRefType>(op.getType ()));
113- LDBG (" matchAndRewrite: to_memref (typeconverted) " << cast<MemRefType>(op.getType ()) << " -> " << targetType);
113+ Type targetType =
114+ getTypeConverter ()->convertType (cast<MemRefType>(op.getType ()));
115+ LDBG (" matchAndRewrite: to_memref (typeconverted) "
116+ << cast<MemRefType>(op.getType ()) << " -> " << targetType);
114117 if (!targetType) {
115118 return rewriter.notifyMatchFailure (op, " failed to convert memref type" );
116119 }
@@ -166,14 +169,10 @@ struct FromMemrefConverter : OpConversionPattern<tptr::FromMemrefOp> {
166169 LDBG (" matchAndRewrite: from_memref (before) " << op);
167170
168171 Value input = adaptor.getInput ();
172+ // 期望此处的输入已通过 TypeConverter 转换为目标 LLVM 结构体类型
169173 if (isa<MemRefType>(input.getType ())) {
170- input = rewriter
171- .create <UnrealizedConversionCastOp>(
172- op.getLoc (),
173- getTypeConverter ()->convertType (
174- cast<MemRefType>(input.getType ())),
175- input)
176- .getResult (0 );
174+ return rewriter.notifyMatchFailure (op,
175+ " expected converted memref descriptor" );
177176 }
178177
179178 // Extract base_ptr (index 0)
@@ -210,7 +209,8 @@ struct UnrealizedCastConverter
210209 }
211210
212211 if (isa<ptr::PtrType>(outputType) ||
213- (isa<LLVM::LLVMPointerType>(inputType) && isa<MemRefType>(outputType))) {
212+ (isa<LLVM::LLVMPointerType>(inputType) &&
213+ isa<MemRefType>(outputType))) {
214214 LDBG (" UnrealizedCast (reject): unsafe pointer conversion " << op);
215215 return rewriter.notifyMatchFailure (op, " unsafe pointer conversion" );
216216 }
@@ -292,9 +292,8 @@ struct ConvertBranchOp : OpConversionPattern<cf::BranchOp> {
292292 return failure ();
293293 }
294294
295- auto newOp =
296- rewriter.replaceOpWithNewOp <cf::BranchOp>(op, op.getDest (),
297- adaptor.getDestOperands ());
295+ auto newOp = rewriter.replaceOpWithNewOp <cf::BranchOp>(
296+ op, op.getDest (), adaptor.getDestOperands ());
298297 LDBG (" matchAndRewrite: cf.br (after) -> " << newOp);
299298 return success ();
300299 }
@@ -341,12 +340,44 @@ struct MemRefAllocConverter : OpConversionPattern<memref::AllocOp> {
341340 totalElements *= dim;
342341 }
343342
344- // For now, use alloca instead of malloc to avoid complex call setup
345- Value totalSize = rewriter.create <LLVM::ConstantOp>(
343+ // Compute total allocation size in bytes = numElements * sizeof(element)
344+ Value numElementsVal = rewriter.create <LLVM::ConstantOp>(
346345 loc, i64Ty, rewriter.getIntegerAttr (i64Ty, totalElements));
347346
348- Value allocatedPtr = rewriter.create <LLVM::AllocaOp>(
349- loc, ptrTy, ptrTy, totalSize, /* alignment=*/ 0 );
347+ // Query pointer size from DataLayout
348+ DataLayout dl = DataLayout::closest (op);
349+ auto ptrSize = dl.getTypeSize (ptrTy);
350+ if (ptrSize.isScalable ()) {
351+ return rewriter.notifyMatchFailure (op,
352+ " scalable pointer size unsupported" );
353+ }
354+ auto fixedPtrSize = static_cast <int64_t >(ptrSize.getFixedValue ());
355+ Value ptrSizeVal = rewriter.create <LLVM::ConstantOp>(
356+ loc, i64Ty, rewriter.getIntegerAttr (i64Ty, fixedPtrSize));
357+
358+ Value totalBytes =
359+ rewriter.create <LLVM::MulOp>(loc, numElementsVal, ptrSizeVal);
360+
361+ // Declare or lookup malloc: ptr (i64)
362+ ModuleOp module = op->getParentOfType <ModuleOp>();
363+ auto mallocName = StringRef (" malloc" );
364+ LLVM::LLVMFuncOp mallocFunc =
365+ module .lookupSymbol <LLVM::LLVMFuncOp>(mallocName);
366+ if (!mallocFunc) {
367+ OpBuilder::InsertionGuard guard (rewriter);
368+ rewriter.setInsertionPointToStart (module .getBody ());
369+ auto mallocType =
370+ LLVM::LLVMFunctionType::get (ptrTy, {i64Ty}, /* isVarArg=*/ false );
371+ mallocFunc =
372+ rewriter.create <LLVM::LLVMFuncOp>(loc, mallocName, mallocType);
373+ }
374+
375+ auto mallocCallee = SymbolRefAttr::get (mallocFunc);
376+ Value allocatedPtr =
377+ rewriter
378+ .create <LLVM::CallOp>(loc, TypeRange{ptrTy}, mallocCallee,
379+ ValueRange{totalBytes})
380+ .getResult ();
350381
351382 // Build memref descriptor struct
352383 Value result = rewriter.create <LLVM::UndefOp>(loc, llvmStructType);
@@ -407,10 +438,13 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
407438 auto ptrTy = LLVM::LLVMPointerType::get (ctx);
408439 auto i64Ty = rewriter.getIntegerType (64 );
409440
410- // Extract base pointer from memref descriptor (index 0)
441+ // Extract aligned pointer and offset from memref descriptor
442+ // aligned_ptr at index 1, offset at index 2
411443 Value memrefDescriptor = adaptor.getMemref ();
412- Value basePtr = rewriter.create <LLVM::ExtractValueOp>(
413- loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr ({0 }));
444+ Value alignedPtr = rewriter.create <LLVM::ExtractValueOp>(
445+ loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr ({1 }));
446+ Value baseOffset = rewriter.create <LLVM::ExtractValueOp>(
447+ loc, i64Ty, memrefDescriptor, rewriter.getDenseI64ArrayAttr ({2 }));
414448
415449 // Calculate linear index from multi-dimensional indices
416450 Value linearIndex = nullptr ;
@@ -425,11 +459,10 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
425459 .getResult (0 );
426460 }
427461 }
428- linearIndex = index;
462+ linearIndex = rewriter. create <LLVM::AddOp>(loc, baseOffset, index) ;
429463 } else {
430464 // Multi-dimensional: linearIndex = i0*stride0 + i1*stride1 + ...
431- linearIndex = rewriter.create <LLVM::ConstantOp>(
432- loc, i64Ty, rewriter.getIntegerAttr (i64Ty, 0 ));
465+ linearIndex = baseOffset;
433466
434467 for (auto [i, index] : llvm::enumerate (adaptor.getIndices ())) {
435468 // Convert index to i64 if needed
@@ -453,8 +486,8 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
453486 }
454487
455488 // GEP to get the address of the element
456- Value elementPtr =
457- rewriter. create <LLVM::GEPOp>(loc, ptrTy, ptrTy, basePtr , linearIndex);
489+ Value elementPtr = rewriter. create <LLVM::GEPOp>(loc, ptrTy, ptrTy,
490+ alignedPtr , linearIndex);
458491
459492 // Store the value
460493 auto storeOp =
@@ -497,16 +530,24 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
497530 auto ptrTy = LLVM::LLVMPointerType::get (ctx);
498531 auto i64Ty = rewriter.getIntegerType (64 );
499532
500- // Extract base pointer from memref descriptor (index 0)
533+ // Extract aligned pointer and offset from memref descriptor
534+ // aligned_ptr at index 1, offset at index 2
501535 Value memrefDescriptor = adaptor.getMemref ();
502- Value basePtr = rewriter.create <LLVM::ExtractValueOp>(
503- loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr ({0 }));
504-
536+ LDBG (" memrefDescriptor " << memrefDescriptor);
537+ Value alignedPtr = rewriter.create <LLVM::ExtractValueOp>(
538+ loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr ({1 }));
539+ LDBG (" basePtr " << rewriter.create <LLVM::ExtractValueOp>(
540+ loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr ({0 })));
541+ LDBG (" alignedPtr " << alignedPtr);
542+ Value baseOffset = rewriter.create <LLVM::ExtractValueOp>(
543+ loc, i64Ty, memrefDescriptor, rewriter.getDenseI64ArrayAttr ({2 }));
544+ LDBG (" baseOffset " << baseOffset);
505545 // Calculate linear index from multi-dimensional indices
506546 Value linearIndex = nullptr ;
507547 if (adaptor.getIndices ().size () == 1 ) {
508548 // Single dimension case
509549 Value index = adaptor.getIndices ()[0 ];
550+ LDBG (" if index " << index);
510551 // Convert index to i64 if needed
511552 if (index.getType () != i64Ty) {
512553 if (isa<IndexType>(index.getType ())) {
@@ -515,11 +556,11 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
515556 .getResult (0 );
516557 }
517558 }
518- linearIndex = index;
559+ linearIndex = rewriter. create <LLVM::AddOp>(loc, baseOffset, index) ;
519560 } else {
520561 // Multi-dimensional: linearIndex = i0*stride0 + i1*stride1 + ...
521- linearIndex = rewriter. create <LLVM::ConstantOp>(
522- loc, i64Ty, rewriter. getIntegerAttr (i64Ty, 0 ) );
562+ linearIndex = baseOffset;
563+ LDBG ( " else index " << linearIndex );
523564
524565 for (auto [i, index] : llvm::enumerate (adaptor.getIndices ())) {
525566 // Convert index to i64 if needed
@@ -543,8 +584,8 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
543584 }
544585
545586 // GEP to get the address of the element
546- Value elementPtr =
547- rewriter. create <LLVM::GEPOp>(loc, ptrTy, ptrTy, basePtr , linearIndex);
587+ Value elementPtr = rewriter. create <LLVM::GEPOp>(loc, ptrTy, ptrTy,
588+ alignedPtr , linearIndex);
548589
549590 // Load the value
550591 Value loadedValue =
0 commit comments