From fe22487907de4153dd492b2b6d3bcd83d6468993 Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Sat, 26 Apr 2025 19:25:54 +0000 Subject: [PATCH 1/7] [mlir][core|ptr] Add `PtrLikeTypeInterface` and casting ops to the `ptr` dialect This patch adds the `PtrLikeTypeInterface` type interface to identify pointer-like types. This interface is defined as: ``` A ptr-like type represents an object storing a memory address. This object is constituted by: - A memory address called the base pointer. The base pointer is an indivisible object. - Optional metadata about the pointer. For example, the size of the memory region associated with the pointer. Furthermore, all ptr-like types have two properties: - The memory space associated with the address held by the pointer. - An optional element type. If the element type is not specified, the pointer is considered opaque. ``` This patch adds this interface to `!ptr.ptr` and the `memref` type. Furthermore, this patch adds necessary ops and type to handle casting between `!ptr.ptr` and ptr-like types. First, it defines the `!ptr.ptr_metadata` type. An opaque type to represent the metadata of a ptr-like type. The rationale behind adding this type, is that at high-level the metadata of a type like `memref` cannot be specified, as its structure is tied to its lowering. The `ptr.get_metadata` operation was added to extract the opaque pointer metadata. The concrete structure of the metadata is only known when the op is lowered. Finally, this patch adds the `ptr.from_ptr` and `ptr.to_ptr` operations. Allowing to cast back and forth between `!ptr.ptr` and ptr-liker types. ```mlir func.func @func(%mr: memref) -> memref { %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> %mda = ptr.get_metadata %mr : memref %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref return %res : memref } ``` --- .../include/mlir/Dialect/Ptr/IR/PtrDialect.td | 49 +++++++++ mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 99 +++++++++++++++++++ mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 49 +++++++++ mlir/include/mlir/IR/BuiltinTypes.h | 18 +++- mlir/include/mlir/IR/BuiltinTypes.td | 2 + mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 75 ++++++++++++++ mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 12 +++ mlir/lib/IR/BuiltinTypes.cpp | 14 +++ mlir/test/Dialect/Ptr/canonicalize.mlir | 58 +++++++++++ mlir/test/Dialect/Ptr/invalid.mlir | 33 +++++++ mlir/test/Dialect/Ptr/ops.mlir | 10 ++ 11 files changed, 418 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/Ptr/invalid.mlir diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 73b2a0857cef3..6631b338db199 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -37,6 +37,7 @@ class Ptr_Type traits = []> def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ MemRefElementTypeInterface, + PtrLikeTypeInterface, VectorElementTypeInterface, DeclareTypeInterfaceMethods ]; + let extraClassDeclaration = [{ + // `PtrLikeTypeInterface` interface methods. + /// Returns `Type()` as this pointer type is opaque. + Type getElementType() const { + return Type(); + } + /// Clones the pointer with specified memory space or returns failure + /// if an `elementType` was specified or if the memory space doesn't + /// implement `MemorySpaceAttrInterface`. + FailureOr clonePtrWith(Attribute memorySpace, + std::optional elementType) const { + if (elementType) + return failure(); + if (auto ms = dyn_cast(memorySpace)) + return cast(get(ms)); + return failure(); + } + /// `!ptr.ptr` types are seen as ptr-like objects with no metadata. + bool hasPtrMetadata() const { + return false; + } + }]; +} + +def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> { + let summary = "Pointer metadata type"; + let description = [{ + The `ptr_metadata` type represents an opaque-view of the metadata associated + with a `ptr-like` object type. + It's an error to get a `ptr_metadata` using `ptr-like` type with no + metadata. + + Example: + + ```mlir + // The metadata associated with a `memref` type. + !ptr.ptr_metadata> + ``` + }]; + let parameters = (ins "PtrLikeTypeInterface":$type); + let assemblyFormat = "`<` $type `>`"; + let builders = [ + TypeBuilderWithInferredContext<(ins + "PtrLikeTypeInterface":$ptrLike), [{ + return $_get(ptrLike.getContext(), ptrLike); + }]> + ]; + let genVerifyDecl = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 791b95ad3559e..8ad475c41c8d3 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -17,6 +17,75 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/OpAsmInterface.td" +//===----------------------------------------------------------------------===// +// FromPtrOp +//===----------------------------------------------------------------------===// + +def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [ + Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata", + "PtrMetadataType::get(cast($_self))"> + ]> { + let summary = "Casts a `!ptr.ptr` value to a ptr-like value."; + let description = [{ + The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's + important to note that: + - The ptr-like object cannot be a `!ptr.ptr`. + - The memory-space of both the `ptr` and ptr-like object must match. + - The cast is side-effect free. + + If the ptr-like object type has metadata, then the operation expects the + metadata as an argument or expects that the flag `trivial_metadata` is set. + If `trivial_metadata` is set, then it is assumed that the metadata can be + reconstructed statically from the pointer-like type. + + Example: + + ```mlir + %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr + %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref + %memref = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref + ``` + }]; + + let arguments = (ins Ptr_PtrType:$ptr, + Optional:$metadata, + UnitProp:$hasTrivialMetadata); + let results = (outs PtrLikeTypeInterface:$result); + let assemblyFormat = [{ + $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)? + attr-dict `:` type($ptr) `->` type($result) + }]; + let hasFolder = 1; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// GetMetadataOp +//===----------------------------------------------------------------------===// + +def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [ + Pure, TypesMatchWith<"metadata type", "ptr", "result", + "PtrMetadataType::get(cast($_self))"> + ]> { + let summary = "SSA value representing pointer metadata."; + let description = [{ + The `get_metadata` operation produces an opaque value that encodes the + metadata of the ptr-like type. + + Example: + + ```mlir + %metadata = ptr.get_metadata %memref : memref + ``` + }]; + + let arguments = (ins PtrLikeTypeInterface:$ptr); + let results = (outs Ptr_PtrMetadata:$result); + let assemblyFormat = [{ + $ptr attr-dict `:` type($ptr) + }]; +} + //===----------------------------------------------------------------------===// // PtrAddOp //===----------------------------------------------------------------------===// @@ -52,6 +121,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ }]; } +//===----------------------------------------------------------------------===// +// ToPtrOp +//===----------------------------------------------------------------------===// + +def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> { + let summary = "Casts a ptr-like value to a `!ptr.ptr` value."; + let description = [{ + The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's + important to note that: + - The ptr-like object cannot be a `!ptr.ptr`. + - The memory-space of both the `ptr` and ptr-like object must match. + - The cast is side-effect free. + + Example: + + ```mlir + %ptr0 = ptr.to_ptr %my_ptr : !my.ptr -> !ptr.ptr<0> + %ptr1 = ptr.to_ptr %memref : memref -> !ptr.ptr<0> + ``` + }]; + + let arguments = (ins PtrLikeTypeInterface:$ptr); + let results = (outs Ptr_PtrType:$result); + let assemblyFormat = [{ + $ptr attr-dict `:` type($ptr) `->` type($result) + }]; + let hasFolder = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // TypeOffsetOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index 4a4f818b46c57..d058f6c4d9651 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -110,6 +110,55 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> { }]; } +//===----------------------------------------------------------------------===// +// PtrLikeTypeInterface +//===----------------------------------------------------------------------===// + +def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + A ptr-like type represents an object storing a memory address. This object + is constituted by: + - A memory address called the base pointer. The base pointer is an + indivisible object. + - Optional metadata about the pointer. For example, the size of the memory + region associated with the pointer. + + Furthermore, all ptr-like types have two properties: + - The memory space associated with the address held by the pointer. + - An optional element type. If the element type is not specified, the + pointer is considered opaque. + }]; + let methods = [ + InterfaceMethod<[{ + Returns the memory space of this ptr-like type. + }], + "::mlir::Attribute", "getMemorySpace">, + InterfaceMethod<[{ + Returns the element type of this ptr-like type. Note: this method can + return `::mlir::Type()`, in which case the pointer is considered opaque. + }], + "::mlir::Type", "getElementType">, + InterfaceMethod<[{ + Returns whether this ptr-like type has non-empty metadata. + }], + "bool", "hasPtrMetadata">, + InterfaceMethod<[{ + Returns a clone of this type with the given memory space and element type, + or `failure` if the type cannot be cloned with the specified arguments. + If the pointer is opaque and `elementType` is not `std::nullopt` the + method will return `failure`. + + If no `elementType` is provided and ptr is not opaque, the `elementType` + of this type is used. + }], + "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins + "::mlir::Attribute":$memorySpace, + "::std::optional<::mlir::Type>":$elementType + )> + ]; +} + //===----------------------------------------------------------------------===// // ShapedType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index df1e02732617d..86ec5c43970b1 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait { /// Note: This class attaches the ShapedType trait to act as a mixin to /// provide many useful utility functions. This inheritance has no effect /// on derived memref types. -class BaseMemRefType : public Type, public ShapedType::Trait { +class BaseMemRefType : public Type, + public PtrLikeTypeInterface::Trait, + public ShapedType::Trait { public: using Type::Type; @@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait { BaseMemRefType cloneWith(std::optional> shape, Type elementType) const; + /// Clone this type with the given memory space and element type. If the + /// provided element type is `std::nullopt`, the current element type of the + /// type is used. + FailureOr + clonePtrWith(Attribute memorySpace, std::optional elementType) const; + // Make sure that base class overloads are visible. using ShapedType::Trait::clone; @@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait { /// New `Attribute getMemorySpace()` method should be used instead. unsigned getMemorySpaceAsInt() const; + /// Returns that this ptr-like object has non-empty ptr metadata. + bool hasPtrMetadata() const { return true; } + /// Allow implicit conversion to ShapedType. operator ShapedType() const { return llvm::cast(*this); } + + /// Allow implicit conversion to PtrLikeTypeInterface. + operator PtrLikeTypeInterface() const { + return llvm::cast(*this); + } }; } // namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 771de01fc8d5d..9ad24e45c8315 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer", //===----------------------------------------------------------------------===// def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ + PtrLikeTypeInterface, ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; @@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> { //===----------------------------------------------------------------------===// def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ + PtrLikeTypeInterface, ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index c21783011452f..80fd7617c9354 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -41,6 +41,54 @@ void PtrDialect::initialize() { >(); } +//===----------------------------------------------------------------------===// +// FromPtrOp +//===----------------------------------------------------------------------===// + +OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) { + // Fold the pattern: + // %ptr = ptr.to_ptr %v : type -> ptr + // (%mda = ptr.get_metadata %v : type)? + // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type + // To: + // %val -> %v + auto toPtr = dyn_cast_or_null(getPtr().getDefiningOp()); + // Cannot fold if it's not a `to_ptr` op or the initial and final types are + // different. + if (!toPtr || toPtr.getPtr().getType() != getType()) + return nullptr; + Value md = getMetadata(); + if (!md) + return toPtr.getPtr(); + // Fold if the metadata can be verified to be equal. + if (auto mdOp = dyn_cast_or_null(md.getDefiningOp()); + mdOp && mdOp.getPtr() == toPtr.getPtr()) + return toPtr.getPtr(); + return nullptr; +} + +LogicalResult FromPtrOp::verify() { + if (isa(getType())) + return emitError() << "the result type cannot be `!ptr.ptr`"; + if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) { + return emitError() + << "expected the input and output to have the same memory space"; + } + bool hasMD = getMetadata() != Value(); + bool hasTrivialMD = getHasTrivialMetadata(); + if (hasMD && hasTrivialMD) { + return emitError() << "expected either a metadata argument or the " + "`trivial_metadata` flag, not both"; + } + if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) { + return emitError() << "expected either a metadata argument or the " + "`trivial_metadata` flag to be set"; + } + if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD)) + return emitError() << "expected no metadata specification"; + return success(); +} + //===----------------------------------------------------------------------===// // PtrAddOp //===----------------------------------------------------------------------===// @@ -55,6 +103,33 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// ToPtrOp +//===----------------------------------------------------------------------===// + +OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) { + // Fold the pattern: + // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type + // %ptr = ptr.to_ptr %val : type -> ptr + // To: + // %ptr -> %p + auto fromPtr = dyn_cast_or_null(getPtr().getDefiningOp()); + // Cannot fold if it's not a `from_ptr` op. + if (!fromPtr) + return nullptr; + return fromPtr.getPtr(); +} + +LogicalResult ToPtrOp::verify() { + if (isa(getPtr().getType())) + return emitError() << "the input value cannot be of type `!ptr.ptr`"; + if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) { + return emitError() + << "expected the input and output to have the same memory space"; + } + return success(); +} + //===----------------------------------------------------------------------===// // TypeOffsetOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp index cab9ca11e679e..7ad2a6bc4c80b 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, } return success(); } + +//===----------------------------------------------------------------------===// +// Pointer metadata +//===----------------------------------------------------------------------===// + +LogicalResult +PtrMetadataType::verify(function_ref emitError, + PtrLikeTypeInterface type) { + if (!type.hasPtrMetadata()) + return emitError() << "the ptr-like type has no metadata"; + return success(); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index d47e360e9dc13..97bab479c79bf 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -376,6 +376,20 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional> shape, return builder; } +FailureOr +BaseMemRefType::clonePtrWith(Attribute memorySpace, + std::optional elementType) const { + Type eTy = elementType ? *elementType : getElementType(); + if (llvm::dyn_cast(*this)) + return cast( + UnrankedMemRefType::get(eTy, memorySpace)); + + MemRefType::Builder builder(llvm::cast(*this)); + builder.setElementType(eTy); + builder.setMemorySpace(memorySpace); + return cast(static_cast(builder)); +} + MemRefType BaseMemRefType::clone(::llvm::ArrayRef shape, Type elementType) const { return ::llvm::cast(cloneWith(shape, elementType)); diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir index ad363d554f247..837f364242beb 100644 --- a/mlir/test/Dialect/Ptr/canonicalize.mlir +++ b/mlir/test/Dialect/Ptr/canonicalize.mlir @@ -13,3 +13,61 @@ func.func @zero_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.gene %res0 = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index return %res0 : !ptr.ptr<#ptr.generic_space> } + +/// Tests the the `from_ptr` folder. +// CHECK-LABEL: @test_from_ptr_0 +// CHECK-SAME: (%[[MEM_REF:.*]]: memref) +func.func @test_from_ptr_0(%mr: memref) -> memref { + // CHECK-NOT: ptr.to_ptr + // CHECK-NOT: ptr.get_metadata + // CHECK-NOT: ptr.from_ptr + // CHECK: return %[[MEM_REF]] + %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %mda = ptr.get_metadata %mr : memref + %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} + +// CHECK-LABEL: @test_from_ptr_1 +// CHECK-SAME: (%[[MEM_REF:.*]]: memref) +func.func @test_from_ptr_1(%mr: memref) -> memref { + // CHECK-NOT: ptr.to_ptr + // CHECK-NOT: ptr.from_ptr + // CHECK: return %[[MEM_REF]] + %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} + +/// Check that the ops cannot be folded because the metadata cannot be guaranteed to be the same. +// CHECK-LABEL: @test_from_ptr_2 +func.func @test_from_ptr_2(%mr: memref, %md: !ptr.ptr_metadata>) -> memref { + // CHECK: ptr.to_ptr + // CHECK: ptr.from_ptr + %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %res = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} + +/// Tests the the `to_ptr` folder. +// CHECK-LABEL: @test_to_ptr_0 +// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space> +func.func @test_to_ptr_0(%ptr: !ptr.ptr<#ptr.generic_space>, %md: !ptr.ptr_metadata>) -> !ptr.ptr<#ptr.generic_space> { + // CHECK: return %[[PTR]] + // CHECK-NOT: ptr.from_ptr + // CHECK-NOT: ptr.to_ptr + %mrf = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref + %res = ptr.to_ptr %mrf : memref -> !ptr.ptr<#ptr.generic_space> + return %res : !ptr.ptr<#ptr.generic_space> +} + +// CHECK-LABEL: @test_to_ptr_1 +// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>) +func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> { + // CHECK-NOT: ptr.from_ptr + // CHECK-NOT: ptr.to_ptr + // CHECK: return %[[PTR]] + %mrf = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %res = ptr.to_ptr %mrf : memref -> !ptr.ptr<#ptr.generic_space> + return %res : !ptr.ptr<#ptr.generic_space> +} diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir new file mode 100644 index 0000000000000..e776e0ee04f90 --- /dev/null +++ b/mlir/test/Dialect/Ptr/invalid.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s + +/// Test `to_ptr` verifiers. +func.func @invalid_to_ptr(%v: memref) { + // expected-error@+1 {{expected the input and output to have the same memory space}} + %r = ptr.to_ptr %v : memref -> !ptr.ptr<#ptr.generic_space> + return +} + +// ----- + +func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) { + // expected-error@+1 {{the input value cannot be of type `!ptr.ptr`}} + %r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space> + return +} + +// ----- + +/// Test `from_ptr` verifiers. +func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) { + // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag to be set}} + %r = ptr.from_ptr %v : !ptr.ptr<#ptr.generic_space> -> memref + return +} + +// ----- + +func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>, %m: !ptr.ptr_metadata>) { + // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag, not both}} + %r = ptr.from_ptr %v metadata %m trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + return +} diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index d763ea221944b..74bff25b4f3e1 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -17,3 +17,13 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<# %res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#ptr.generic_space>, index return %res : !ptr.ptr<#ptr.generic_space> } + +/// Check cast ops assembly. +// CHECK-LABEL: @cast_ops +func.func @cast_ops(%mr: memref) -> memref { + %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %mda = ptr.get_metadata %mr : memref + %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + %mr0 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} From afac5f4118573633bc89451c36ee05144cd4baf2 Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Sun, 27 Apr 2025 07:29:15 -0400 Subject: [PATCH 2/7] Update mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td Co-authored-by: Mehdi Amini --- mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 8ad475c41c8d3..55cc47a41d03b 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -31,7 +31,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [ important to note that: - The ptr-like object cannot be a `!ptr.ptr`. - The memory-space of both the `ptr` and ptr-like object must match. - - The cast is side-effect free. + - The cast is Pure (no UB and side-effect free). If the ptr-like object type has metadata, then the operation expects the metadata as an argument or expects that the flag `trivial_metadata` is set. From d9fd27ec7f65eda27587cc602d43da762c475fd5 Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Sun, 27 Apr 2025 13:33:46 +0000 Subject: [PATCH 3/7] add tests for chains of casts --- mlir/test/Dialect/Ptr/canonicalize.mlir | 48 +++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir index 837f364242beb..2b9c8489f352e 100644 --- a/mlir/test/Dialect/Ptr/canonicalize.mlir +++ b/mlir/test/Dialect/Ptr/canonicalize.mlir @@ -49,6 +49,21 @@ func.func @test_from_ptr_2(%mr: memref, %md: !ptr.ptr_m return %res : memref } +// Check the folding of `to_ptr -> from_ptr` chains. +// CHECK-LABEL: @test_from_ptr_3 +// CHECK-SAME: (%[[MEM_REF:.*]]: memref) +func.func @test_from_ptr_3(%mr0: memref) -> memref { + // CHECK-NOT: ptr.to_ptr + // CHECK-NOT: ptr.from_ptr + // CHECK: return %[[MEM_REF]] + %mda = ptr.get_metadata %mr0 : memref + %ptr0 = ptr.to_ptr %mr0 : memref -> !ptr.ptr<#ptr.generic_space> + %mrf0 = ptr.from_ptr %ptr0 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + %ptr1 = ptr.to_ptr %mrf0 : memref -> !ptr.ptr<#ptr.generic_space> + %mrf1 = ptr.from_ptr %ptr1 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + return %mrf1 : memref +} + /// Tests the the `to_ptr` folder. // CHECK-LABEL: @test_to_ptr_0 // CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space> @@ -71,3 +86,36 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge %res = ptr.to_ptr %mrf : memref -> !ptr.ptr<#ptr.generic_space> return %res : !ptr.ptr<#ptr.generic_space> } + +// Check the folding of `from_ptr -> to_ptr` chains. +// CHECK-LABEL: @test_to_ptr_2 +// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space> +func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> { + // CHECK-NOT: ptr.from_ptr + // CHECK-NOT: ptr.to_ptr + // CHECK: return %[[PTR]] + %mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %ptr1 = ptr.to_ptr %mrf0 : memref -> !ptr.ptr<#ptr.generic_space> + %mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %ptr2 = ptr.to_ptr %mrf1 : memref -> !ptr.ptr<#ptr.generic_space> + %mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %res = ptr.to_ptr %mrf2 : memref -> !ptr.ptr<#ptr.generic_space> + return %res : !ptr.ptr<#ptr.generic_space> +} + +// Check the folding of chains with different metadata. +// CHECK-LABEL: @test_cast_chain_folding +// CHECK-SAME: (%[[MEM_REF:.*]]: memref +func.func @test_cast_chain_folding(%mr: memref, %md: !ptr.ptr_metadata>) -> memref { + // CHECK-NOT: ptr.to_ptr + // CHECK-NOT: ptr.from_ptr + // CHECK: return %[[MEM_REF]] + %ptr1 = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %memrefWithOtherMd = ptr.from_ptr %ptr1 metadata %md : !ptr.ptr<#ptr.generic_space> -> memref + %ptr = ptr.to_ptr %memrefWithOtherMd : memref -> !ptr.ptr<#ptr.generic_space> + %mda = ptr.get_metadata %mr : memref + // The chain can be folded because: the ptr always has the same value because + // `to_ptr` is a loss-less cast and %mda comes from the original memref. + %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} From 634d03ba963e01422a96b7046b44d37fdfbd4b12 Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Sat, 10 May 2025 13:34:46 +0000 Subject: [PATCH 4/7] make folders work on cast sequences --- mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 52 +++++++++++++++++--------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index 80fd7617c9354..c0310446e3cea 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -52,19 +52,28 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) { // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type // To: // %val -> %v - auto toPtr = dyn_cast_or_null(getPtr().getDefiningOp()); - // Cannot fold if it's not a `to_ptr` op or the initial and final types are - // different. - if (!toPtr || toPtr.getPtr().getType() != getType()) - return nullptr; - Value md = getMetadata(); - if (!md) - return toPtr.getPtr(); - // Fold if the metadata can be verified to be equal. - if (auto mdOp = dyn_cast_or_null(md.getDefiningOp()); - mdOp && mdOp.getPtr() == toPtr.getPtr()) - return toPtr.getPtr(); - return nullptr; + Value ptrLike; + FromPtrOp fromPtr = *this; + while (fromPtr != nullptr) { + auto toPtr = dyn_cast_or_null(fromPtr.getPtr().getDefiningOp()); + // Cannot fold if it's not a `to_ptr` op or the initial and final types are + // different. + if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType()) + return ptrLike; + Value md = fromPtr.getMetadata(); + // If there's no metadata in the op, either the cast never requires metadata + // or the op has the trivial metadata flag set, therefore fold. + if (!md) + ptrLike = toPtr.getPtr(); + // Fold if the metadata can be verified to be equal. + else if (auto mdOp = dyn_cast_or_null(md.getDefiningOp()); + mdOp && mdOp.getPtr() == toPtr.getPtr()) + ptrLike = toPtr.getPtr(); + // Check for a sequence of casts. + fromPtr = dyn_cast_or_null(ptrLike ? ptrLike.getDefiningOp() + : nullptr); + } + return ptrLike; } LogicalResult FromPtrOp::verify() { @@ -113,11 +122,18 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) { // %ptr = ptr.to_ptr %val : type -> ptr // To: // %ptr -> %p - auto fromPtr = dyn_cast_or_null(getPtr().getDefiningOp()); - // Cannot fold if it's not a `from_ptr` op. - if (!fromPtr) - return nullptr; - return fromPtr.getPtr(); + Value ptr; + ToPtrOp toPtr = *this; + while (toPtr != nullptr) { + auto fromPtr = dyn_cast_or_null(toPtr.getPtr().getDefiningOp()); + // Cannot fold if it's not a `from_ptr` op. + if (!fromPtr) + return ptr; + ptr = fromPtr.getPtr(); + // Check for chains of casts. + toPtr = dyn_cast_or_null(ptr.getDefiningOp()); + } + return ptr; } LogicalResult ToPtrOp::verify() { From f090320893d026975f30edfbd3693a286be4da59 Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Wed, 14 May 2025 16:44:14 +0000 Subject: [PATCH 5/7] remove trivial_metadata flag from from_ptr op --- mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 17 +++++++---------- mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 15 +-------------- mlir/test/Dialect/Ptr/canonicalize.mlir | 10 +++++----- mlir/test/Dialect/Ptr/invalid.mlir | 17 ----------------- mlir/test/Dialect/Ptr/ops.mlir | 2 +- 5 files changed, 14 insertions(+), 47 deletions(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 55cc47a41d03b..37eb91fa6a338 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -33,27 +33,24 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [ - The memory-space of both the `ptr` and ptr-like object must match. - The cast is Pure (no UB and side-effect free). - If the ptr-like object type has metadata, then the operation expects the - metadata as an argument or expects that the flag `trivial_metadata` is set. - If `trivial_metadata` is set, then it is assumed that the metadata can be - reconstructed statically from the pointer-like type. + The optional `metadata` operand exists to provide any ptr-like metadata + that might be required to perform the cast. Example: ```mlir %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref - %memref = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref + + // Cast the `%ptr` to a memref without utilizing metadata. + %memref = ptr.from_ptr %ptr : !ptr.ptr<0> -> memref ``` }]; - let arguments = (ins Ptr_PtrType:$ptr, - Optional:$metadata, - UnitProp:$hasTrivialMetadata); + let arguments = (ins Ptr_PtrType:$ptr, Optional:$metadata); let results = (outs PtrLikeTypeInterface:$result); let assemblyFormat = [{ - $ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)? - attr-dict `:` type($ptr) `->` type($result) + $ptr (`metadata` $metadata^)? attr-dict `:` type($ptr) `->` type($result) }]; let hasFolder = 1; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index c0310446e3cea..ffa924b20ab59 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -61,8 +61,7 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) { if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType()) return ptrLike; Value md = fromPtr.getMetadata(); - // If there's no metadata in the op, either the cast never requires metadata - // or the op has the trivial metadata flag set, therefore fold. + // If there's no metadata in the op fold the op. if (!md) ptrLike = toPtr.getPtr(); // Fold if the metadata can be verified to be equal. @@ -83,18 +82,6 @@ LogicalResult FromPtrOp::verify() { return emitError() << "expected the input and output to have the same memory space"; } - bool hasMD = getMetadata() != Value(); - bool hasTrivialMD = getHasTrivialMetadata(); - if (hasMD && hasTrivialMD) { - return emitError() << "expected either a metadata argument or the " - "`trivial_metadata` flag, not both"; - } - if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) { - return emitError() << "expected either a metadata argument or the " - "`trivial_metadata` flag to be set"; - } - if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD)) - return emitError() << "expected no metadata specification"; return success(); } diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir index 2b9c8489f352e..dfc679acb2ed4 100644 --- a/mlir/test/Dialect/Ptr/canonicalize.mlir +++ b/mlir/test/Dialect/Ptr/canonicalize.mlir @@ -35,7 +35,7 @@ func.func @test_from_ptr_1(%mr: memref) -> memref -> !ptr.ptr<#ptr.generic_space> - %res = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %res = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref return %res : memref } @@ -82,7 +82,7 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge // CHECK-NOT: ptr.from_ptr // CHECK-NOT: ptr.to_ptr // CHECK: return %[[PTR]] - %mrf = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %mrf = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref %res = ptr.to_ptr %mrf : memref -> !ptr.ptr<#ptr.generic_space> return %res : !ptr.ptr<#ptr.generic_space> } @@ -94,11 +94,11 @@ func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.g // CHECK-NOT: ptr.from_ptr // CHECK-NOT: ptr.to_ptr // CHECK: return %[[PTR]] - %mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %mrf0 = ptr.from_ptr %ptr0 : !ptr.ptr<#ptr.generic_space> -> memref %ptr1 = ptr.to_ptr %mrf0 : memref -> !ptr.ptr<#ptr.generic_space> - %mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %mrf1 = ptr.from_ptr %ptr1 : !ptr.ptr<#ptr.generic_space> -> memref %ptr2 = ptr.to_ptr %mrf1 : memref -> !ptr.ptr<#ptr.generic_space> - %mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %mrf2 = ptr.from_ptr %ptr2 : !ptr.ptr<#ptr.generic_space> -> memref %res = ptr.to_ptr %mrf2 : memref -> !ptr.ptr<#ptr.generic_space> return %res : !ptr.ptr<#ptr.generic_space> } diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir index e776e0ee04f90..19fd715e5bba6 100644 --- a/mlir/test/Dialect/Ptr/invalid.mlir +++ b/mlir/test/Dialect/Ptr/invalid.mlir @@ -14,20 +14,3 @@ func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) { %r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space> return } - -// ----- - -/// Test `from_ptr` verifiers. -func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>) { - // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag to be set}} - %r = ptr.from_ptr %v : !ptr.ptr<#ptr.generic_space> -> memref - return -} - -// ----- - -func.func @invalid_from_ptr(%v: !ptr.ptr<#ptr.generic_space>, %m: !ptr.ptr_metadata>) { - // expected-error@+1 {{expected either a metadata argument or the `trivial_metadata` flag, not both}} - %r = ptr.from_ptr %v metadata %m trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref - return -} diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index 74bff25b4f3e1..eed3272d98da9 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -24,6 +24,6 @@ func.func @cast_ops(%mr: memref) -> memref -> !ptr.ptr<#ptr.generic_space> %mda = ptr.get_metadata %mr : memref %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref - %mr0 = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref + %mr0 = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref return %res : memref } From 5b1a2ad53b9ca95400531bd6ce930eda88e83446 Mon Sep 17 00:00:00 2001 From: Fabian Mora <6982088+fabianmcg@users.noreply.github.com> Date: Fri, 6 Jun 2025 22:52:34 +0000 Subject: [PATCH 6/7] address reviewer comments --- mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td | 5 +++-- mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | 14 +++++++------- mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 8 ++++++-- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 6631b338db199..7407d74ce3a87 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -93,8 +93,9 @@ def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> { let description = [{ The `ptr_metadata` type represents an opaque-view of the metadata associated with a `ptr-like` object type. - It's an error to get a `ptr_metadata` using `ptr-like` type with no - metadata. + + Note: It's a verification error to construct a `ptr_metadata` type using a + `ptr-like` type with no metadata. Example: diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 37eb91fa6a338..1523762efc18f 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -39,11 +39,11 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [ Example: ```mlir - %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr - %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref + %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr + %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref // Cast the `%ptr` to a memref without utilizing metadata. - %memref = ptr.from_ptr %ptr : !ptr.ptr<0> -> memref + %memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref ``` }]; @@ -98,8 +98,8 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ Example: ```mlir - %x_off = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32 - %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32 + %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32 + %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32 ``` }]; @@ -134,8 +134,8 @@ def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> { Example: ```mlir - %ptr0 = ptr.to_ptr %my_ptr : !my.ptr -> !ptr.ptr<0> - %ptr1 = ptr.to_ptr %memref : memref -> !ptr.ptr<0> + %ptr0 = ptr.to_ptr %my_ptr : !my.ptr -> !ptr.ptr<#ptr.generic_space> + %ptr1 = ptr.to_ptr %memref : memref -> !ptr.ptr<#ptr.generic_space> ``` }]; diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index d058f6c4d9651..367aeb6ac512b 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -119,8 +119,12 @@ def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> { let description = [{ A ptr-like type represents an object storing a memory address. This object is constituted by: - - A memory address called the base pointer. The base pointer is an - indivisible object. + - A memory address called the base pointer. This pointer is treated as a + bag of bits without any assumed structure. The bit-width of the base + pointer must be a compile-time constant. However, the bit-width may remain + opaque or unavailable during transformations that do not depend on the + base pointer. Finally, it is considered indivisible in the sense that as + a `PtrLikeTypeInterface` value, it has no metadata. - Optional metadata about the pointer. For example, the size of the memory region associated with the pointer. From 105bcec7d423a2863c046397518d4873316f04fc Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Mon, 16 Jun 2025 19:18:00 +0000 Subject: [PATCH 7/7] address reviewer comments and rebase --- mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 14 ++++++++------ mlir/test/Dialect/Ptr/canonicalize.mlir | 6 +++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index ffa924b20ab59..c488144508128 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -61,13 +61,15 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) { if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType()) return ptrLike; Value md = fromPtr.getMetadata(); - // If there's no metadata in the op fold the op. - if (!md) - ptrLike = toPtr.getPtr(); - // Fold if the metadata can be verified to be equal. - else if (auto mdOp = dyn_cast_or_null(md.getDefiningOp()); - mdOp && mdOp.getPtr() == toPtr.getPtr()) + // If the type has trivial metadata fold. + if (!fromPtr.getType().hasPtrMetadata()) { ptrLike = toPtr.getPtr(); + } else if (md) { + // Fold if the metadata can be verified to be equal. + if (auto mdOp = dyn_cast_or_null(md.getDefiningOp()); + mdOp && mdOp.getPtr() == toPtr.getPtr()) + ptrLike = toPtr.getPtr(); + } // Check for a sequence of casts. fromPtr = dyn_cast_or_null(ptrLike ? ptrLike.getDefiningOp() : nullptr); diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir index dfc679acb2ed4..e50cd1b76caf3 100644 --- a/mlir/test/Dialect/Ptr/canonicalize.mlir +++ b/mlir/test/Dialect/Ptr/canonicalize.mlir @@ -28,12 +28,12 @@ func.func @test_from_ptr_0(%mr: memref) -> memref } +/// Check the op doesn't fold because folding a ptr-type with metadata requires knowing the origin of the metadata. // CHECK-LABEL: @test_from_ptr_1 // CHECK-SAME: (%[[MEM_REF:.*]]: memref) func.func @test_from_ptr_1(%mr: memref) -> memref { - // CHECK-NOT: ptr.to_ptr - // CHECK-NOT: ptr.from_ptr - // CHECK: return %[[MEM_REF]] + // CHECK: ptr.to_ptr + // CHECK: ptr.from_ptr %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> %res = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref return %res : memref