-
Notifications
You must be signed in to change notification settings - Fork 15k
[mlir][ptr] Add conversion to LLVM for all existing ptr ops
#156053
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
|
01b13b1 to
55e9a9b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This patch adds conversion support from the ptr dialect to the LLVM dialect, enabling users to convert pointer operations and types to their LLVM equivalents. This is a stopgap measure allowing immediate use of the ptr dialect before some conversions are moved to translation in the future.
- Implements LLVM conversions for all existing `ptr` dialect operations
- Adds comprehensive test coverage for various pointer and memref scenarios
- Integrates the conversion into the MLIR build system and registration
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir | Comprehensive test suite covering all ptr dialect operations and edge cases |
| mlir/lib/RegisterAllExtensions.cpp | Registers the new PtrToLLVM conversion interface |
| mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp | Core implementation of conversion patterns for all ptr dialect operations |
| mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt | Build configuration for the new conversion library |
| mlir/lib/Conversion/CMakeLists.txt | Adds PtrToLLVM subdirectory to the build |
| mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h | Header file with public API declarations |
|
@llvm/pr-subscribers-mlir Author: Fabian Mora (fabianmcg) Changes: This patch adds conversion to LLVM for all existing pointer ops. This is a stop gap measure to allow users to use the `ptr` dialect now. In the future some of these conversions will be removed, and added as translations, for example `ptradd`. Example:
func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_space>, %arg1: index) -> !ptr.ptr<#ptr.generic_space> {
%0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
%1 = ptr.type_offset f32 : index
%2 = arith.muli %1, %arg1 : index
%3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
return %3 : !ptr.ptr<#ptr.generic_space>
}
// mlir-opt --convert-to-llvm --canonicalize --cse
llvm.func @test_memref_ptradd_indexing(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64, %arg8: i64, %arg9: i64) -> !llvm.ptr {
%0 = llvm.mlir.zero : !llvm.ptr
%1 = llvm.getelementptr %0[1] : (!llvm.ptr) -> !llvm.ptr, f32
%2 = llvm.ptrtoint %1 : !llvm.ptr to i64
%3 = llvm.mul %2, %arg9 : i64
%4 = llvm.getelementptr %arg1[%3] : (!llvm.ptr, i64) -> !llvm.ptr, i8
llvm.return %4 : !llvm.ptr
}
Patch is 51.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156053.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h
new file mode 100644
index 0000000000000..0ff92bc85668c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/PtrToLLVM/PtrToLLVM.h
@@ -0,0 +1,27 @@
+//===- PtrToLLVM.h - Ptr to LLVM dialect conversion -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H
+#define MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+namespace ptr {
+/// Populate `patterns` with the patterns converting the `ptr` dialect to the
+/// LLVM dialect, and add the type and address-space conversions they rely on
+/// to `converter`.
+void populatePtrToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+/// Register the convert to LLVM interface for the `ptr` dialect in `registry`.
+void registerConvertPtrToLLVMInterface(DialectRegistry &registry);
+} // namespace ptr
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_PTRTOLLVM_PTRTOLLVM_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 134fe8e14ca38..71986f83c4870 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -50,6 +50,7 @@ add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
+add_subdirectory(PtrToLLVM)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToEmitC)
diff --git a/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..2d416be13ee30
--- /dev/null
+++ b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRPtrToLLVM
+ PtrToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PtrToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRPtrDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ )
diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
new file mode 100644
index 0000000000000..a0758aa8b1369
--- /dev/null
+++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp
@@ -0,0 +1,440 @@
+//===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/TypeUtilities.h"
+#include <type_traits>
+
+using namespace mlir;
+
+namespace {
+//===----------------------------------------------------------------------===//
+// FromPtrOpConversion
+//===----------------------------------------------------------------------===//
+/// Conversion pattern for `ptr::FromPtrOp`.
+struct FromPtrOpConversion final : ConvertOpToLLVMPattern<ptr::FromPtrOp> {
+  using ConvertOpToLLVMPattern<ptr::FromPtrOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOpConversion
+//===----------------------------------------------------------------------===//
+/// Conversion pattern for `ptr::GetMetadataOp`.
+struct GetMetadataOpConversion final
+    : ConvertOpToLLVMPattern<ptr::GetMetadataOp> {
+  using ConvertOpToLLVMPattern<ptr::GetMetadataOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// PtrAddOpConversion
+//===----------------------------------------------------------------------===//
+/// Conversion pattern for `ptr::PtrAddOp`.
+struct PtrAddOpConversion final : ConvertOpToLLVMPattern<ptr::PtrAddOp> {
+  using ConvertOpToLLVMPattern<ptr::PtrAddOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// ToPtrOpConversion
+//===----------------------------------------------------------------------===//
+/// Conversion pattern for `ptr::ToPtrOp`.
+struct ToPtrOpConversion final : ConvertOpToLLVMPattern<ptr::ToPtrOp> {
+  using ConvertOpToLLVMPattern<ptr::ToPtrOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// TypeOffsetOpConversion
+//===----------------------------------------------------------------------===//
+/// Conversion pattern for `ptr::TypeOffsetOp`.
+struct TypeOffsetOpConversion final
+    : ConvertOpToLLVMPattern<ptr::TypeOffsetOp> {
+  using ConvertOpToLLVMPattern<ptr::TypeOffsetOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Internal functions
+//===----------------------------------------------------------------------===//
+
+/// Build the LLVM struct type that carries the metadata of memref `type`.
+///
+/// The struct always starts with a pointer field (the allocated pointer) in
+/// the address space that `typeConverter` assigns to the memref. It is followed
+/// by one index-typed field per dynamic quantity, in this order: the offset (if
+/// dynamic), every dynamic size, and every dynamic stride. Static quantities
+/// get no field.
+///
+/// Returns failure when the address space cannot be resolved or when the
+/// strides and offset of `type` cannot be computed.
+static FailureOr<LLVM::LLVMStructType>
+createMemRefMetadataType(MemRefType type,
+                         const LLVMTypeConverter &typeConverter) {
+  MLIRContext *ctx = type.getContext();
+
+  // Resolve the address space the converter uses for this memref.
+  FailureOr<unsigned> addrSpace = typeConverter.getMemRefAddressSpace(type);
+  if (failed(addrSpace))
+    return failure();
+
+  // The struct leads with a pointer in the resolved address space.
+  SmallVector<Type, 5> fields;
+  fields.push_back(LLVM::LLVMPointerType::get(ctx, *addrSpace));
+
+  // Query the layout of the memref.
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(type.getStridesAndOffset(strides, offset)))
+    return failure();
+
+  // Every dynamic quantity is stored using the converter's index type.
+  Type indexTy = typeConverter.getIndexType();
+  auto appendIfDynamic = [&](int64_t quantity) {
+    if (quantity == ShapedType::kDynamic)
+      fields.push_back(indexTy);
+  };
+
+  // Field order: offset, then sizes, then strides.
+  appendIfDynamic(offset);
+  for (int64_t dim : type.getShape())
+    appendIfDynamic(dim);
+  for (int64_t stride : strides)
+    appendIfDynamic(stride);
+
+  return LLVM::LLVMStructType::getLiteral(ctx, fields);
+}
+
+//===----------------------------------------------------------------------===//
+// FromPtrOpConversion
+//===----------------------------------------------------------------------===//
+
+/// Rebuild a memref descriptor from a pointer and its metadata struct.
+///
+/// The converted pointer operand becomes the aligned pointer of the
+/// descriptor. The allocated pointer is read from field 0 of the metadata, and
+/// each dynamic offset/size/stride is read from the following metadata fields
+/// in that order; static quantities are materialized as constants.
+LogicalResult FromPtrOpConversion::matchAndRewrite(
+    ptr::FromPtrOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Only memref results can be lowered.
+  auto memrefTy = dyn_cast<MemRefType>(op.getResult().getType());
+  if (!memrefTy)
+    return rewriter.notifyMatchFailure(op, "Expected memref result type");
+
+  // The metadata operand is required whenever the result type carries it.
+  if (!op.getMetadata() && op.getType().hasPtrMetadata()) {
+    return rewriter.notifyMatchFailure(
+        op, "Can convert only memrefs with metadata");
+  }
+
+  // Type of the descriptor that replaces the op.
+  Type descTy = getTypeConverter()->convertType(memrefTy);
+  if (!descTy)
+    return rewriter.notifyMatchFailure(op, "Failed to convert result type");
+
+  // Static layout information of the memref.
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memrefTy.getStridesAndOffset(strides, offset))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "Failed to get the strides and offset");
+  }
+  ArrayRef<int64_t> shape = memrefTy.getShape();
+
+  Location loc = op.getLoc();
+  auto desc = MemRefDescriptor::poison(rewriter, loc, descTy);
+
+  // Reads metadata fields sequentially, starting at field 0.
+  Value metadata = adaptor.getMetadata();
+  unsigned nextField = 0;
+  auto readNextField = [&]() -> Value {
+    return rewriter.create<LLVM::ExtractValueOp>(loc, metadata, nextField++);
+  };
+
+  // Field 0 is the allocated pointer; the aligned pointer is the operand.
+  Value allocatedPtr = readNextField();
+  desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+  desc.setAlignedPtr(rewriter, loc, adaptor.getPtr());
+
+  // Offset: dynamic values come from the metadata, static ones are constants.
+  if (offset != ShapedType::kDynamic)
+    desc.setConstantOffset(rewriter, loc, offset);
+  else
+    desc.setOffset(rewriter, loc, readNextField());
+
+  // Sizes.
+  for (unsigned i = 0, e = shape.size(); i < e; ++i) {
+    if (shape[i] != ShapedType::kDynamic) {
+      desc.setConstantSize(rewriter, loc, i, shape[i]);
+      continue;
+    }
+    desc.setSize(rewriter, loc, i, readNextField());
+  }
+
+  // Strides.
+  for (unsigned i = 0, e = strides.size(); i < e; ++i) {
+    if (strides[i] != ShapedType::kDynamic) {
+      desc.setConstantStride(rewriter, loc, i, strides[i]);
+      continue;
+    }
+    desc.setStride(rewriter, loc, i, readNextField());
+  }
+
+  rewriter.replaceOp(op, static_cast<Value>(desc));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetMetadataOpConversion
+//===----------------------------------------------------------------------===//
+
+/// Pack the metadata of a memref descriptor into an LLVM struct.
+///
+/// The struct type comes from `createMemRefMetadataType`. Field 0 receives the
+/// allocated pointer of the descriptor; the following fields receive the
+/// dynamic offset, the dynamic sizes and the dynamic strides, in that order.
+LogicalResult GetMetadataOpConversion::matchAndRewrite(
+    ptr::GetMetadataOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Only memref operands are handled.
+  auto memrefTy = dyn_cast<MemRefType>(op.getPtr().getType());
+  if (!memrefTy)
+    return rewriter.notifyMatchFailure(op, "Only memref metadata is supported");
+
+  // Struct type holding the metadata.
+  FailureOr<LLVM::LLVMStructType> metadataTy =
+      createMemRefMetadataType(memrefTy, *getTypeConverter());
+  if (failed(metadataTy)) {
+    return rewriter.notifyMatchFailure(op,
+                                       "Failed to create the metadata type");
+  }
+
+  // View the converted operand as a memref descriptor.
+  MemRefDescriptor descriptor(adaptor.getPtr());
+
+  // Static layout information of the memref.
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memrefTy.getStridesAndOffset(strides, offset))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "Failed to get the strides and offset");
+  }
+  ArrayRef<int64_t> shape = memrefTy.getShape();
+
+  // Start from an undef struct and fill its fields one after the other.
+  Location loc = op.getLoc();
+  Value metadata = rewriter.create<LLVM::UndefOp>(loc, *metadataTy);
+  unsigned nextField = 0;
+  auto appendField = [&](Value field) {
+    metadata = rewriter.create<LLVM::InsertValueOp>(loc, metadata, field,
+                                                    nextField++);
+  };
+
+  // Field 0: the allocated pointer.
+  appendField(descriptor.allocatedPtr(rewriter, loc));
+
+  // Dynamic offset, if any.
+  if (offset == ShapedType::kDynamic)
+    appendField(descriptor.offset(rewriter, loc));
+
+  // Dynamic sizes.
+  for (unsigned i = 0, e = shape.size(); i < e; ++i) {
+    if (shape[i] == ShapedType::kDynamic)
+      appendField(descriptor.size(rewriter, loc, i));
+  }
+
+  // Dynamic strides.
+  for (unsigned i = 0, e = strides.size(); i < e; ++i) {
+    if (strides[i] == ShapedType::kDynamic)
+      appendField(descriptor.stride(rewriter, loc, i));
+  }
+
+  rewriter.replaceOp(op, metadata);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// PtrAddOpConversion
+//===----------------------------------------------------------------------===//
+
+/// Lower `ptr::PtrAddOp` to a byte-addressed `LLVM::GEPOp`.
+///
+/// The converted base must be an LLVM pointer, otherwise the match fails. The
+/// converted offset is the single GEP index over an `i8` element type, and the
+/// op's `PtrAddFlags` are mapped onto the same-named `LLVM::GEPNoWrapFlags`.
+LogicalResult
+PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  // Get and check the base.
+  Value base = adaptor.getBase();
+  if (!isa<LLVM::LLVMPointerType>(base.getType()))
+    return rewriter.notifyMatchFailure(op, "Incompatible pointer type");
+
+  // Get the offset.
+  Value offset = adaptor.getOffset();
+
+  // Ptr assumes the offset is in bytes.
+  Type elementType = IntegerType::get(rewriter.getContext(), 8);
+
+  // Convert the `ptradd` flags. Start from `none` so `flags` is never read
+  // uninitialized if the enum value is not one of the cases listed below.
+  LLVM::GEPNoWrapFlags flags = LLVM::GEPNoWrapFlags::none;
+  switch (op.getFlags()) {
+  case ptr::PtrAddFlags::none:
+    flags = LLVM::GEPNoWrapFlags::none;
+    break;
+  case ptr::PtrAddFlags::nusw:
+    flags = LLVM::GEPNoWrapFlags::nusw;
+    break;
+  case ptr::PtrAddFlags::nuw:
+    flags = LLVM::GEPNoWrapFlags::nuw;
+    break;
+  case ptr::PtrAddFlags::inbounds:
+    flags = LLVM::GEPNoWrapFlags::inbounds;
+    break;
+  }
+
+  // Create the GEP operation with appropriate arguments
+  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.getType(), elementType,
+                                           base, ValueRange{offset}, flags);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ToPtrOpConversion
+//===----------------------------------------------------------------------===//
+
+/// Replace `ptr::ToPtrOp` on a memref with the aligned pointer stored in the
+/// converted memref descriptor. Non-memref inputs are rejected.
+LogicalResult
+ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+  // Only memref inputs are handled.
+  Type sourceTy = op.getPtr().getType();
+  if (!isa<MemRefType>(sourceTy))
+    return rewriter.notifyMatchFailure(op, "Expected a memref input");
+
+  // The result is the aligned pointer of the descriptor.
+  MemRefDescriptor descriptor(adaptor.getPtr());
+  Value alignedPtr = descriptor.alignedPtr(rewriter, op.getLoc());
+  rewriter.replaceOp(op, alignedPtr);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TypeOffsetOpConversion
+//===----------------------------------------------------------------------===//
+
+/// Lower `ptr::TypeOffsetOp` by taking the address of element 1 of the
+/// converted type relative to a null pointer (a GEP with index 1 on
+/// `llvm.mlir.zero`) and converting that pointer to the converted result type
+/// with `LLVM::PtrToIntOp`.
+LogicalResult TypeOffsetOpConversion::matchAndRewrite(
+    ptr::TypeOffsetOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  const LLVMTypeConverter *converter = getTypeConverter();
+
+  // Element type whose offset is requested.
+  Type elementTy = converter->convertType(op.getElementType());
+  if (!elementTy)
+    return rewriter.notifyMatchFailure(op, "Couldn't convert the type");
+
+  // Integer type of the result.
+  Type resultTy = converter->convertType(op.getResult().getType());
+  if (!resultTy)
+    return rewriter.notifyMatchFailure(op, "Couldn't convert the result type");
+
+  // TODO: Use MLIR's data layout. We don't use it because overall support is
+  // still flaky.
+
+  Location loc = op.getLoc();
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
+
+  // GEP from a null base with index 1 over the element type.
+  Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, llvmPtrTy);
+  auto gep = LLVM::GEPOp::create(rewriter, loc, llvmPtrTy, elementTy, nullPtr,
+                                 ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)}));
+
+  // The integer value of that pointer replaces the op.
+  rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, resultTy, gep.getRes());
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Dialect interface exposing the Ptr to LLVM conversion.
+struct PtrToLLVMDialectInterface final : ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+  /// Load the LLVM dialect, which the conversion patterns produce ops of.
+  void loadDependentDialects(MLIRContext *context) const override {
+    context->loadDialect<LLVM::LLVMDialect>();
+  }
+
+  /// Add the Ptr to LLVM patterns and type conversions. The conversion target
+  /// is left untouched.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget & /*target*/, LLVMTypeConverter &converter,
+      RewritePatternSet &patterns) const override {
+    ptr::populatePtrToLLVMConversionPatterns(converter, patterns);
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// API
+//===----------------------------------------------------------------------===//
+
+/// Register on `converter` the conversions needed by the `ptr` dialect and add
+/// the lowering patterns to `patterns`:
+///  - `ptr::GenericSpaceAttr` memory spaces map to the `i32` integer 0;
+///  - `ptr::PtrType` maps to an LLVM pointer in the converted address space;
+///  - `ptr::PtrMetadataType` of a memref maps to its metadata struct type.
+void mlir::ptr::populatePtrToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  // Address space conversion for the generic space.
+  converter.addTypeAttributeConversion(
+      [](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace)
+          -> TypeConverter::AttributeConversionResult {
+        // Only applies when the attribute is the memory space of the type.
+        if (type.getMemorySpace() == memorySpace)
+          return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0);
+        return TypeConverter::AttributeConversionResult::na();
+      });
+
+  // `ptr::PtrType` conversion.
+  converter.addConversion([&](ptr::PtrType type) -> Type {
+    std::optional<Attribute> converted =
+        converter.convertTypeAttribute(type, type.getMemorySpace());
+    if (!converted)
+      return Type();
+    // The memory space must have been converted to an integer attribute.
+    auto addressSpace = dyn_cast_or_null<IntegerAttr>(*converted);
+    if (!addressSpace)
+      return Type();
+    return LLVM::LLVMPointerType::get(type.getContext(),
+                                      addressSpace.getValue().getSExtValue());
+  });
+
+  // `ptr::PtrMetadataType` conversion; only memref metadata is handled.
+  converter.addConversion([&](ptr::PtrMetadataType type) -> Type {
+    auto memrefTy = dyn_cast<MemRefType>(type.getType());
+    if (!memrefTy)
+      return Type();
+    FailureOr<LLVM::LLVMStructType> metadataTy =
+        createMemRefMetadataType(memrefTy, converter);
+    if (failed(metadataTy))
+      return Type();
+    return *metadataTy;
+  });
+
+  // Lowering patterns.
+  patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion,
+               ToPtrOpConversion, TypeOffsetOpConversion>(converter);
+}
+
+/// Add to `registry` an extension that attaches `PtrToLLVMDialectInterface` to
+/// the `ptr` dialect.
+void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) {
+    dialect->addInterfaces<PtrToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 232ddaf6762c4..69a85dbe141ce 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -28,6 +28,7 @@
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -81,6 +82,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
registerConvertMemRefToEmitCInterface(registry);
registerConvertMemRefToLLVMInterface(registry);
registerConvertNVVMToLLVMInterface(registry);
+ ptr::registerConvertPtrToLLVMInterface(registry);
registerConvertOpenMPToLLVMInterface(registry);
registerConvertSCFToEmitCInterface(registry);
ub::registerConvertUBToLLVMInterface(regis...
[truncated]
|
|
Thanks for contributing this!
the IR: # get memref from alloca
%alloca_0 = memref.alloca() : memref<64xi64, #ptr.generic_space>
# get aligned ptr from memref descriptor
%0 = ptr.to_ptr %alloca : memref<64xi64, #ptr.generic_space> -> <#ptr.generic_space>
# do something with the aligned ptr
%1 = call @enif_get_int64(%arg0, %arg1, %0) : (!ptr.ptr<#ptr.generic_space>, i64, !ptr.ptr<#ptr.generic_space>) -> i32
# load the ptr.ptr
%2 = ptr.load %0 : !ptr.ptr<#ptr.generic_space> -> i64
# get aligned ptr from memref descriptor
%10 = llvm.extractvalue %9[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
# do something with the aligned ptr
%12 = llvm.call @enif_get_int64(%arg0, %arg1, %10) : (!llvm.ptr, i64, !llvm.ptr) -> i32
# load the ptr.ptr, which is a now a llvm.ptr
%11 = builtin.unrealized_conversion_cast %10 : !llvm.ptr to !ptr.ptr<#ptr.generic_space>
%13 = ptr.load %11 : !ptr.ptr<#ptr.generic_space> -> i64 |
|
Thank you for bringing this up, I'll fix it this weekend. |
This patch adds conversion to LLVM for all existing pointer ops. This is a stop gap measure to allow users to use the
`ptr` dialect now. In the future some of these conversions will be removed, and added as translations, for example `ptradd`. Example: