Skip to content

Commit 3df3d07

Browse files
committed
new review from Dasor
1 parent 37fc68a commit 3df3d07

File tree

2 files changed

+107
-44
lines changed

2 files changed

+107
-44
lines changed

lib/Conversion/TPtrToLLVM/TPtrToLLVM.cpp

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
66
#include "mlir/Dialect/MemRef/IR/MemRef.h"
77
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
8+
#include "mlir/IR/BuiltinAttributes.h"
9+
#include "mlir/IR/BuiltinOps.h"
810
#include "mlir/IR/BuiltinTypes.h"
911
#include "mlir/IR/PatternMatch.h"
1012
#include "mlir/IR/Value.h"
@@ -26,7 +28,6 @@ static bool isOneToOneCast(UnrealizedConversionCastOp op) {
2628
return (op.getInputs().size() == 1 && op->getNumResults() == 1);
2729
}
2830

29-
3031
// PtrAddOp -> llvm.getelementptr conversion
3132
struct PtrAddConverter : OpConversionPattern<tptr::PtrAddOp> {
3233
using OpConversionPattern<tptr::PtrAddOp>::OpConversionPattern;
@@ -109,8 +110,10 @@ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
109110
}
110111
}
111112

112-
Type targetType = getTypeConverter()->convertType(cast<MemRefType>(op.getType()));
113-
LDBG("matchAndRewrite: to_memref (typeconverted) " << cast<MemRefType>(op.getType()) << " -> " << targetType);
113+
Type targetType =
114+
getTypeConverter()->convertType(cast<MemRefType>(op.getType()));
115+
LDBG("matchAndRewrite: to_memref (typeconverted) "
116+
<< cast<MemRefType>(op.getType()) << " -> " << targetType);
114117
if (!targetType) {
115118
return rewriter.notifyMatchFailure(op, "failed to convert memref type");
116119
}
@@ -166,14 +169,10 @@ struct FromMemrefConverter : OpConversionPattern<tptr::FromMemrefOp> {
166169
LDBG("matchAndRewrite: from_memref (before) " << op);
167170

168171
Value input = adaptor.getInput();
172+
// The input here is expected to have already been converted by the TypeConverter to the target LLVM struct type
169173
if (isa<MemRefType>(input.getType())) {
170-
input = rewriter
171-
.create<UnrealizedConversionCastOp>(
172-
op.getLoc(),
173-
getTypeConverter()->convertType(
174-
cast<MemRefType>(input.getType())),
175-
input)
176-
.getResult(0);
174+
return rewriter.notifyMatchFailure(op,
175+
"expected converted memref descriptor");
177176
}
178177

179178
// Extract base_ptr (index 0)
@@ -210,7 +209,8 @@ struct UnrealizedCastConverter
210209
}
211210

212211
if (isa<ptr::PtrType>(outputType) ||
213-
(isa<LLVM::LLVMPointerType>(inputType) && isa<MemRefType>(outputType))) {
212+
(isa<LLVM::LLVMPointerType>(inputType) &&
213+
isa<MemRefType>(outputType))) {
214214
LDBG("UnrealizedCast (reject): unsafe pointer conversion " << op);
215215
return rewriter.notifyMatchFailure(op, "unsafe pointer conversion");
216216
}
@@ -292,9 +292,8 @@ struct ConvertBranchOp : OpConversionPattern<cf::BranchOp> {
292292
return failure();
293293
}
294294

295-
auto newOp =
296-
rewriter.replaceOpWithNewOp<cf::BranchOp>(op, op.getDest(),
297-
adaptor.getDestOperands());
295+
auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
296+
op, op.getDest(), adaptor.getDestOperands());
298297
LDBG("matchAndRewrite: cf.br (after) -> " << newOp);
299298
return success();
300299
}
@@ -341,12 +340,44 @@ struct MemRefAllocConverter : OpConversionPattern<memref::AllocOp> {
341340
totalElements *= dim;
342341
}
343342

344-
// For now, use alloca instead of malloc to avoid complex call setup
345-
Value totalSize = rewriter.create<LLVM::ConstantOp>(
343+
// Compute total allocation size in bytes = numElements * pointer size (the per-element size is taken from the DataLayout size of the pointer type below)
344+
Value numElementsVal = rewriter.create<LLVM::ConstantOp>(
346345
loc, i64Ty, rewriter.getIntegerAttr(i64Ty, totalElements));
347346

348-
Value allocatedPtr = rewriter.create<LLVM::AllocaOp>(
349-
loc, ptrTy, ptrTy, totalSize, /*alignment=*/0);
347+
// Query pointer size from DataLayout
348+
DataLayout dl = DataLayout::closest(op);
349+
auto ptrSize = dl.getTypeSize(ptrTy);
350+
if (ptrSize.isScalable()) {
351+
return rewriter.notifyMatchFailure(op,
352+
"scalable pointer size unsupported");
353+
}
354+
auto fixedPtrSize = static_cast<int64_t>(ptrSize.getFixedValue());
355+
Value ptrSizeVal = rewriter.create<LLVM::ConstantOp>(
356+
loc, i64Ty, rewriter.getIntegerAttr(i64Ty, fixedPtrSize));
357+
358+
Value totalBytes =
359+
rewriter.create<LLVM::MulOp>(loc, numElementsVal, ptrSizeVal);
360+
361+
// Declare or look up malloc with signature: ptr (i64)
362+
ModuleOp module = op->getParentOfType<ModuleOp>();
363+
auto mallocName = StringRef("malloc");
364+
LLVM::LLVMFuncOp mallocFunc =
365+
module.lookupSymbol<LLVM::LLVMFuncOp>(mallocName);
366+
if (!mallocFunc) {
367+
OpBuilder::InsertionGuard guard(rewriter);
368+
rewriter.setInsertionPointToStart(module.getBody());
369+
auto mallocType =
370+
LLVM::LLVMFunctionType::get(ptrTy, {i64Ty}, /*isVarArg=*/false);
371+
mallocFunc =
372+
rewriter.create<LLVM::LLVMFuncOp>(loc, mallocName, mallocType);
373+
}
374+
375+
auto mallocCallee = SymbolRefAttr::get(mallocFunc);
376+
Value allocatedPtr =
377+
rewriter
378+
.create<LLVM::CallOp>(loc, TypeRange{ptrTy}, mallocCallee,
379+
ValueRange{totalBytes})
380+
.getResult();
350381

351382
// Build memref descriptor struct
352383
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmStructType);
@@ -407,10 +438,13 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
407438
auto ptrTy = LLVM::LLVMPointerType::get(ctx);
408439
auto i64Ty = rewriter.getIntegerType(64);
409440

410-
// Extract base pointer from memref descriptor (index 0)
441+
// Extract aligned pointer and offset from memref descriptor
442+
// aligned_ptr at index 1, offset at index 2
411443
Value memrefDescriptor = adaptor.getMemref();
412-
Value basePtr = rewriter.create<LLVM::ExtractValueOp>(
413-
loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr({0}));
444+
Value alignedPtr = rewriter.create<LLVM::ExtractValueOp>(
445+
loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr({1}));
446+
Value baseOffset = rewriter.create<LLVM::ExtractValueOp>(
447+
loc, i64Ty, memrefDescriptor, rewriter.getDenseI64ArrayAttr({2}));
414448

415449
// Calculate linear index from multi-dimensional indices
416450
Value linearIndex = nullptr;
@@ -425,11 +459,10 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
425459
.getResult(0);
426460
}
427461
}
428-
linearIndex = index;
462+
linearIndex = rewriter.create<LLVM::AddOp>(loc, baseOffset, index);
429463
} else {
430464
// Multi-dimensional: linearIndex = i0*stride0 + i1*stride1 + ...
431-
linearIndex = rewriter.create<LLVM::ConstantOp>(
432-
loc, i64Ty, rewriter.getIntegerAttr(i64Ty, 0));
465+
linearIndex = baseOffset;
433466

434467
for (auto [i, index] : llvm::enumerate(adaptor.getIndices())) {
435468
// Convert index to i64 if needed
@@ -453,8 +486,8 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
453486
}
454487

455488
// GEP to get the address of the element
456-
Value elementPtr =
457-
rewriter.create<LLVM::GEPOp>(loc, ptrTy, ptrTy, basePtr, linearIndex);
489+
Value elementPtr = rewriter.create<LLVM::GEPOp>(loc, ptrTy, ptrTy,
490+
alignedPtr, linearIndex);
458491

459492
// Store the value
460493
auto storeOp =
@@ -497,16 +530,24 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
497530
auto ptrTy = LLVM::LLVMPointerType::get(ctx);
498531
auto i64Ty = rewriter.getIntegerType(64);
499532

500-
// Extract base pointer from memref descriptor (index 0)
533+
// Extract aligned pointer and offset from memref descriptor
534+
// aligned_ptr at index 1, offset at index 2
501535
Value memrefDescriptor = adaptor.getMemref();
502-
Value basePtr = rewriter.create<LLVM::ExtractValueOp>(
503-
loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr({0}));
504-
536+
LDBG("memrefDescriptor " << memrefDescriptor);
537+
Value alignedPtr = rewriter.create<LLVM::ExtractValueOp>(
538+
loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr({1}));
539+
LDBG("basePtr " << rewriter.create<LLVM::ExtractValueOp>(
540+
loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr({0})));
541+
LDBG("alignedPtr " << alignedPtr);
542+
Value baseOffset = rewriter.create<LLVM::ExtractValueOp>(
543+
loc, i64Ty, memrefDescriptor, rewriter.getDenseI64ArrayAttr({2}));
544+
LDBG("baseOffset " << baseOffset);
505545
// Calculate linear index from multi-dimensional indices
506546
Value linearIndex = nullptr;
507547
if (adaptor.getIndices().size() == 1) {
508548
// Single dimension case
509549
Value index = adaptor.getIndices()[0];
550+
LDBG("if index " << index);
510551
// Convert index to i64 if needed
511552
if (index.getType() != i64Ty) {
512553
if (isa<IndexType>(index.getType())) {
@@ -515,11 +556,11 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
515556
.getResult(0);
516557
}
517558
}
518-
linearIndex = index;
559+
linearIndex = rewriter.create<LLVM::AddOp>(loc, baseOffset, index);
519560
} else {
520561
// Multi-dimensional: linearIndex = i0*stride0 + i1*stride1 + ...
521-
linearIndex = rewriter.create<LLVM::ConstantOp>(
522-
loc, i64Ty, rewriter.getIntegerAttr(i64Ty, 0));
562+
linearIndex = baseOffset;
563+
LDBG("else index " << linearIndex);
523564

524565
for (auto [i, index] : llvm::enumerate(adaptor.getIndices())) {
525566
// Convert index to i64 if needed
@@ -543,8 +584,8 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
543584
}
544585

545586
// GEP to get the address of the element
546-
Value elementPtr =
547-
rewriter.create<LLVM::GEPOp>(loc, ptrTy, ptrTy, basePtr, linearIndex);
587+
Value elementPtr = rewriter.create<LLVM::GEPOp>(loc, ptrTy, ptrTy,
588+
alignedPtr, linearIndex);
548589

549590
// Load the value
550591
Value loadedValue =

lib/Conversion/TPtrToLLVM/TPtrToLLVMPass.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
44
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
55
#include "mlir/Dialect/MemRef/IR/MemRef.h"
6+
#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
7+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
68
#include "mlir/IR/BuiltinDialect.h"
79
#include "mlir/IR/BuiltinTypes.h"
810
#include "mlir/IR/PatternMatch.h"
@@ -30,18 +32,9 @@ namespace tptr {
3032
namespace {
3133

3234
struct TptrToLLVMTypeConverter : TypeConverter {
33-
Type convertPtrPointerType(ptr::PtrType type) {
34-
auto ctx = type.getContext();
35-
return LLVM::LLVMPointerType::get(ctx);
36-
}
37-
3835
TptrToLLVMTypeConverter(MLIRContext *ctx) {
3936
addConversion([](Type type) -> Type { return type; });
4037

41-
addConversion([&](ptr::PtrType type) -> std::optional<Type> {
42-
return convertPtrPointerType(type);
43-
});
44-
4538
addConversion([&](MemRefType type) -> std::optional<Type> {
4639
auto elementType = type.getElementType();
4740
auto ctx = type.getContext();
@@ -57,6 +50,35 @@ struct TptrToLLVMTypeConverter : TypeConverter {
5750

5851
return LLVM::LLVMStructType::getLiteral(ctx, types);
5952
});
53+
addTypeAttributeConversion(
54+
[&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
55+
-> TypeConverter::AttributeConversionResult {
56+
if (type.getMemorySpace() != memorySpace)
57+
return TypeConverter::AttributeConversionResult::na();
58+
return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
59+
});
60+
addTypeAttributeConversion(
61+
[&](PtrLikeTypeInterface type, tptr::DefaultMemorySpaceAttr memorySpace)
62+
-> TypeConverter::AttributeConversionResult {
63+
if (type.getMemorySpace() != memorySpace)
64+
return TypeConverter::AttributeConversionResult::na();
65+
// Default memory space maps to LLVM addrspace(0).
66+
return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
67+
});
68+
69+
// Add type conversions.
70+
addConversion([&](ptr::PtrType type) -> Type {
71+
LDBG("MemorySpace " << type.getMemorySpace());
72+
std::optional<Attribute> maybeAttr =
73+
convertTypeAttribute(type, type.getMemorySpace());
74+
auto memSpace =
75+
maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr();
76+
if (!memSpace) {
77+
return {};
78+
}
79+
return LLVM::LLVMPointerType::get(type.getContext(),
80+
memSpace.getValue().getSExtValue());
81+
});
6082

6183
auto createUnrealizedCast = [](OpBuilder &builder, Type resultType,
6284
ValueRange inputs, Location loc) -> Value {

0 commit comments

Comments
 (0)