diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td index 73b2a0857cef3..7407d74ce3a87 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td @@ -37,6 +37,7 @@ class Ptr_Type traits = []> def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [ MemRefElementTypeInterface, + PtrLikeTypeInterface, VectorElementTypeInterface, DeclareTypeInterfaceMethods ]; + let extraClassDeclaration = [{ + // `PtrLikeTypeInterface` interface methods. + /// Returns `Type()` as this pointer type is opaque. + Type getElementType() const { + return Type(); + } + /// Clones the pointer with specified memory space or returns failure + /// if an `elementType` was specified or if the memory space doesn't + /// implement `MemorySpaceAttrInterface`. + FailureOr clonePtrWith(Attribute memorySpace, + std::optional elementType) const { + if (elementType) + return failure(); + if (auto ms = dyn_cast(memorySpace)) + return cast(get(ms)); + return failure(); + } + /// `!ptr.ptr` types are seen as ptr-like objects with no metadata. + bool hasPtrMetadata() const { + return false; + } + }]; +} + +def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> { + let summary = "Pointer metadata type"; + let description = [{ + The `ptr_metadata` type represents an opaque-view of the metadata associated + with a `ptr-like` object type. + + Note: It's a verification error to construct a `ptr_metadata` type using a + `ptr-like` type with no metadata. + + Example: + + ```mlir + // The metadata associated with a `memref` type. + !ptr.ptr_metadata> + ``` + }]; + let parameters = (ins "PtrLikeTypeInterface":$type); + let assemblyFormat = "`<` $type `>`"; + let builders = [ + TypeBuilderWithInferredContext<(ins + "PtrLikeTypeInterface":$ptrLike), [{ + return $_get(ptrLike.getContext(), ptrLike); + }]> + ]; + let genVerifyDecl = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td index 791b95ad3559e..1523762efc18f 100644 --- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td +++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td @@ -17,6 +17,72 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/OpAsmInterface.td" +//===----------------------------------------------------------------------===// +// FromPtrOp +//===----------------------------------------------------------------------===// + +def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [ + Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata", + "PtrMetadataType::get(cast($_self))"> + ]> { + let summary = "Casts a `!ptr.ptr` value to a ptr-like value."; + let description = [{ + The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's + important to note that: + - The ptr-like object cannot be a `!ptr.ptr`. + - The memory-space of both the `ptr` and ptr-like object must match. + - The cast is Pure (no UB and side-effect free). + + The optional `metadata` operand exists to provide any ptr-like metadata + that might be required to perform the cast. + + Example: + + ```mlir + %typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr + %memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref + + // Cast the `%ptr` to a memref without utilizing metadata. + %memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref + ``` + }]; + + let arguments = (ins Ptr_PtrType:$ptr, Optional:$metadata); + let results = (outs PtrLikeTypeInterface:$result); + let assemblyFormat = [{ + $ptr (`metadata` $metadata^)? attr-dict `:` type($ptr) `->` type($result) + }]; + let hasFolder = 1; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// GetMetadataOp +//===----------------------------------------------------------------------===// + +def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [ + Pure, TypesMatchWith<"metadata type", "ptr", "result", + "PtrMetadataType::get(cast($_self))"> + ]> { + let summary = "SSA value representing pointer metadata."; + let description = [{ + The `get_metadata` operation produces an opaque value that encodes the + metadata of the ptr-like type. + + Example: + + ```mlir + %metadata = ptr.get_metadata %memref : memref + ``` + }]; + + let arguments = (ins PtrLikeTypeInterface:$ptr); + let results = (outs Ptr_PtrMetadata:$result); + let assemblyFormat = [{ + $ptr attr-dict `:` type($ptr) + }]; +} + //===----------------------------------------------------------------------===// // PtrAddOp //===----------------------------------------------------------------------===// @@ -32,8 +98,8 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ Example: ```mlir - %x_off = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32 - %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32 + %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32 + %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32 ``` }]; @@ -52,6 +118,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [ }]; } +//===----------------------------------------------------------------------===// +// ToPtrOp +//===----------------------------------------------------------------------===// + +def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> { + let summary = "Casts a ptr-like value to a `!ptr.ptr` value."; + let description = [{ + The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's + important to note that: + - The ptr-like object cannot be a `!ptr.ptr`. + - The memory-space of both the `ptr` and ptr-like object must match. + - The cast is side-effect free. + + Example: + + ```mlir + %ptr0 = ptr.to_ptr %my_ptr : !my.ptr -> !ptr.ptr<#ptr.generic_space> + %ptr1 = ptr.to_ptr %memref : memref -> !ptr.ptr<#ptr.generic_space> + ``` + }]; + + let arguments = (ins PtrLikeTypeInterface:$ptr); + let results = (outs Ptr_PtrType:$result); + let assemblyFormat = [{ + $ptr attr-dict `:` type($ptr) `->` type($result) + }]; + let hasFolder = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // TypeOffsetOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index 4a4f818b46c57..367aeb6ac512b 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -110,6 +110,59 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> { }]; } +//===----------------------------------------------------------------------===// +// PtrLikeTypeInterface +//===----------------------------------------------------------------------===// + +def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + A ptr-like type represents an object storing a memory address. This object + is constituted by: + - A memory address called the base pointer. This pointer is treated as a + bag of bits without any assumed structure. The bit-width of the base + pointer must be a compile-time constant. However, the bit-width may remain + opaque or unavailable during transformations that do not depend on the + base pointer. Finally, it is considered indivisible in the sense that as + a `PtrLikeTypeInterface` value, it has no metadata. + - Optional metadata about the pointer. For example, the size of the memory + region associated with the pointer. + + Furthermore, all ptr-like types have two properties: + - The memory space associated with the address held by the pointer. + - An optional element type. If the element type is not specified, the + pointer is considered opaque. + }]; + let methods = [ + InterfaceMethod<[{ + Returns the memory space of this ptr-like type. + }], + "::mlir::Attribute", "getMemorySpace">, + InterfaceMethod<[{ + Returns the element type of this ptr-like type. Note: this method can + return `::mlir::Type()`, in which case the pointer is considered opaque. + }], + "::mlir::Type", "getElementType">, + InterfaceMethod<[{ + Returns whether this ptr-like type has non-empty metadata. + }], + "bool", "hasPtrMetadata">, + InterfaceMethod<[{ + Returns a clone of this type with the given memory space and element type, + or `failure` if the type cannot be cloned with the specified arguments. + If the pointer is opaque and `elementType` is not `std::nullopt` the + method will return `failure`. + + If no `elementType` is provided and ptr is not opaque, the `elementType` + of this type is used. + }], + "::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins + "::mlir::Attribute":$memorySpace, + "::std::optional<::mlir::Type>":$elementType + )> + ]; +} + //===----------------------------------------------------------------------===// // ShapedType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index df1e02732617d..86ec5c43970b1 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait { /// Note: This class attaches the ShapedType trait to act as a mixin to /// provide many useful utility functions. This inheritance has no effect /// on derived memref types. -class BaseMemRefType : public Type, public ShapedType::Trait { +class BaseMemRefType : public Type, + public PtrLikeTypeInterface::Trait, + public ShapedType::Trait { public: using Type::Type; @@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait { BaseMemRefType cloneWith(std::optional> shape, Type elementType) const; + /// Clone this type with the given memory space and element type. If the + /// provided element type is `std::nullopt`, the current element type of the + /// type is used. + FailureOr + clonePtrWith(Attribute memorySpace, std::optional elementType) const; + // Make sure that base class overloads are visible. using ShapedType::Trait::clone; @@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait { /// New `Attribute getMemorySpace()` method should be used instead. unsigned getMemorySpaceAsInt() const; + /// Returns that this ptr-like object has non-empty ptr metadata. + bool hasPtrMetadata() const { return true; } + /// Allow implicit conversion to ShapedType. operator ShapedType() const { return llvm::cast(*this); } + + /// Allow implicit conversion to PtrLikeTypeInterface. + operator PtrLikeTypeInterface() const { + return llvm::cast(*this); + } }; } // namespace mlir diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 771de01fc8d5d..9ad24e45c8315 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer", //===----------------------------------------------------------------------===// def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ + PtrLikeTypeInterface, ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; @@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> { //===----------------------------------------------------------------------===// def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ + PtrLikeTypeInterface, ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index c21783011452f..c488144508128 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -41,6 +41,52 @@ void PtrDialect::initialize() { >(); } +//===----------------------------------------------------------------------===// +// FromPtrOp +//===----------------------------------------------------------------------===// + +OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) { + // Fold the pattern: + // %ptr = ptr.to_ptr %v : type -> ptr + // (%mda = ptr.get_metadata %v : type)? + // %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type + // To: + // %val -> %v + Value ptrLike; + FromPtrOp fromPtr = *this; + while (fromPtr != nullptr) { + auto toPtr = dyn_cast_or_null(fromPtr.getPtr().getDefiningOp()); + // Cannot fold if it's not a `to_ptr` op or the initial and final types are + // different. + if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType()) + return ptrLike; + Value md = fromPtr.getMetadata(); + // If the type has trivial metadata fold. + if (!fromPtr.getType().hasPtrMetadata()) { + ptrLike = toPtr.getPtr(); + } else if (md) { + // Fold if the metadata can be verified to be equal. + if (auto mdOp = dyn_cast_or_null(md.getDefiningOp()); + mdOp && mdOp.getPtr() == toPtr.getPtr()) + ptrLike = toPtr.getPtr(); + } + // Check for a sequence of casts. + fromPtr = dyn_cast_or_null(ptrLike ? ptrLike.getDefiningOp() + : nullptr); + } + return ptrLike; +} + +LogicalResult FromPtrOp::verify() { + if (isa(getType())) + return emitError() << "the result type cannot be `!ptr.ptr`"; + if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) { + return emitError() + << "expected the input and output to have the same memory space"; + } + return success(); +} + //===----------------------------------------------------------------------===// // PtrAddOp //===----------------------------------------------------------------------===// @@ -55,6 +101,40 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// ToPtrOp +//===----------------------------------------------------------------------===// + +OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) { + // Fold the pattern: + // %val = ptr.from_ptr %p (metadata ...)? : ptr -> type + // %ptr = ptr.to_ptr %val : type -> ptr + // To: + // %ptr -> %p + Value ptr; + ToPtrOp toPtr = *this; + while (toPtr != nullptr) { + auto fromPtr = dyn_cast_or_null(toPtr.getPtr().getDefiningOp()); + // Cannot fold if it's not a `from_ptr` op. + if (!fromPtr) + return ptr; + ptr = fromPtr.getPtr(); + // Check for chains of casts. + toPtr = dyn_cast_or_null(ptr.getDefiningOp()); + } + return ptr; +} + +LogicalResult ToPtrOp::verify() { + if (isa(getPtr().getType())) + return emitError() << "the input value cannot be of type `!ptr.ptr`"; + if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) { + return emitError() + << "expected the input and output to have the same memory space"; + } + return success(); +} + //===----------------------------------------------------------------------===// // TypeOffsetOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp index cab9ca11e679e..7ad2a6bc4c80b 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, } return success(); } + +//===----------------------------------------------------------------------===// +// Pointer metadata +//===----------------------------------------------------------------------===// + +LogicalResult +PtrMetadataType::verify(function_ref emitError, + PtrLikeTypeInterface type) { + if (!type.hasPtrMetadata()) + return emitError() << "the ptr-like type has no metadata"; + return success(); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index d47e360e9dc13..97bab479c79bf 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -376,6 +376,20 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional> shape, return builder; } +FailureOr +BaseMemRefType::clonePtrWith(Attribute memorySpace, + std::optional elementType) const { + Type eTy = elementType ? *elementType : getElementType(); + if (llvm::dyn_cast(*this)) + return cast( + UnrankedMemRefType::get(eTy, memorySpace)); + + MemRefType::Builder builder(llvm::cast(*this)); + builder.setElementType(eTy); + builder.setMemorySpace(memorySpace); + return cast(static_cast(builder)); +} + MemRefType BaseMemRefType::clone(::llvm::ArrayRef shape, Type elementType) const { return ::llvm::cast(cloneWith(shape, elementType)); diff --git a/mlir/test/Dialect/Ptr/canonicalize.mlir b/mlir/test/Dialect/Ptr/canonicalize.mlir index ad363d554f247..e50cd1b76caf3 100644 --- a/mlir/test/Dialect/Ptr/canonicalize.mlir +++ b/mlir/test/Dialect/Ptr/canonicalize.mlir @@ -13,3 +13,109 @@ func.func @zero_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.gene %res0 = ptr.ptr_add %ptr, %off : !ptr.ptr<#ptr.generic_space>, index return %res0 : !ptr.ptr<#ptr.generic_space> } + +/// Tests the the `from_ptr` folder. +// CHECK-LABEL: @test_from_ptr_0 +// CHECK-SAME: (%[[MEM_REF:.*]]: memref) +func.func @test_from_ptr_0(%mr: memref) -> memref { + // CHECK-NOT: ptr.to_ptr + // CHECK-NOT: ptr.get_metadata + // CHECK-NOT: ptr.from_ptr + // CHECK: return %[[MEM_REF]] + %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %mda = ptr.get_metadata %mr : memref + %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} + +/// Check the op doesn't fold because folding a ptr-type with metadata requires knowing the origin of the metadata. +// CHECK-LABEL: @test_from_ptr_1 +// CHECK-SAME: (%[[MEM_REF:.*]]: memref) +func.func @test_from_ptr_1(%mr: memref) -> memref { + // CHECK: ptr.to_ptr + // CHECK: ptr.from_ptr + %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %res = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} + +/// Check that the ops cannot be folded because the metadata cannot be guaranteed to be the same. +// CHECK-LABEL: @test_from_ptr_2 +func.func @test_from_ptr_2(%mr: memref, %md: !ptr.ptr_metadata>) -> memref { + // CHECK: ptr.to_ptr + // CHECK: ptr.from_ptr + %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %res = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} + +// Check the folding of `to_ptr -> from_ptr` chains. +// CHECK-LABEL: @test_from_ptr_3 +// CHECK-SAME: (%[[MEM_REF:.*]]: memref) +func.func @test_from_ptr_3(%mr0: memref) -> memref { + // CHECK-NOT: ptr.to_ptr + // CHECK-NOT: ptr.from_ptr + // CHECK: return %[[MEM_REF]] + %mda = ptr.get_metadata %mr0 : memref + %ptr0 = ptr.to_ptr %mr0 : memref -> !ptr.ptr<#ptr.generic_space> + %mrf0 = ptr.from_ptr %ptr0 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + %ptr1 = ptr.to_ptr %mrf0 : memref -> !ptr.ptr<#ptr.generic_space> + %mrf1 = ptr.from_ptr %ptr1 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + return %mrf1 : memref +} + +/// Tests the the `to_ptr` folder. +// CHECK-LABEL: @test_to_ptr_0 +// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space> +func.func @test_to_ptr_0(%ptr: !ptr.ptr<#ptr.generic_space>, %md: !ptr.ptr_metadata>) -> !ptr.ptr<#ptr.generic_space> { + // CHECK: return %[[PTR]] + // CHECK-NOT: ptr.from_ptr + // CHECK-NOT: ptr.to_ptr + %mrf = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref + %res = ptr.to_ptr %mrf : memref -> !ptr.ptr<#ptr.generic_space> + return %res : !ptr.ptr<#ptr.generic_space> +} + +// CHECK-LABEL: @test_to_ptr_1 +// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>) +func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> { + // CHECK-NOT: ptr.from_ptr + // CHECK-NOT: ptr.to_ptr + // CHECK: return %[[PTR]] + %mrf = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref + %res = ptr.to_ptr %mrf : memref -> !ptr.ptr<#ptr.generic_space> + return %res : !ptr.ptr<#ptr.generic_space> +} + +// Check the folding of `from_ptr -> to_ptr` chains. +// CHECK-LABEL: @test_to_ptr_2 +// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space> +func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> { + // CHECK-NOT: ptr.from_ptr + // CHECK-NOT: ptr.to_ptr + // CHECK: return %[[PTR]] + %mrf0 = ptr.from_ptr %ptr0 : !ptr.ptr<#ptr.generic_space> -> memref + %ptr1 = ptr.to_ptr %mrf0 : memref -> !ptr.ptr<#ptr.generic_space> + %mrf1 = ptr.from_ptr %ptr1 : !ptr.ptr<#ptr.generic_space> -> memref + %ptr2 = ptr.to_ptr %mrf1 : memref -> !ptr.ptr<#ptr.generic_space> + %mrf2 = ptr.from_ptr %ptr2 : !ptr.ptr<#ptr.generic_space> -> memref + %res = ptr.to_ptr %mrf2 : memref -> !ptr.ptr<#ptr.generic_space> + return %res : !ptr.ptr<#ptr.generic_space> +} + +// Check the folding of chains with different metadata. +// CHECK-LABEL: @test_cast_chain_folding +// CHECK-SAME: (%[[MEM_REF:.*]]: memref +func.func @test_cast_chain_folding(%mr: memref, %md: !ptr.ptr_metadata>) -> memref { + // CHECK-NOT: ptr.to_ptr + // CHECK-NOT: ptr.from_ptr + // CHECK: return %[[MEM_REF]] + %ptr1 = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %memrefWithOtherMd = ptr.from_ptr %ptr1 metadata %md : !ptr.ptr<#ptr.generic_space> -> memref + %ptr = ptr.to_ptr %memrefWithOtherMd : memref -> !ptr.ptr<#ptr.generic_space> + %mda = ptr.get_metadata %mr : memref + // The chain can be folded because: the ptr always has the same value because + // `to_ptr` is a loss-less cast and %mda comes from the original memref. + %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +} diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir new file mode 100644 index 0000000000000..19fd715e5bba6 --- /dev/null +++ b/mlir/test/Dialect/Ptr/invalid.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s + +/// Test `to_ptr` verifiers. +func.func @invalid_to_ptr(%v: memref) { + // expected-error@+1 {{expected the input and output to have the same memory space}} + %r = ptr.to_ptr %v : memref -> !ptr.ptr<#ptr.generic_space> + return +} + +// ----- + +func.func @invalid_to_ptr(%v: !ptr.ptr<#ptr.generic_space>) { + // expected-error@+1 {{the input value cannot be of type `!ptr.ptr`}} + %r = ptr.to_ptr %v : !ptr.ptr<#ptr.generic_space> -> !ptr.ptr<#ptr.generic_space> + return +} diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir index d763ea221944b..eed3272d98da9 100644 --- a/mlir/test/Dialect/Ptr/ops.mlir +++ b/mlir/test/Dialect/Ptr/ops.mlir @@ -17,3 +17,13 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<# %res3 = ptr.ptr_add inbounds %ptr, %off : !ptr.ptr<#ptr.generic_space>, index return %res : !ptr.ptr<#ptr.generic_space> } + +/// Check cast ops assembly. +// CHECK-LABEL: @cast_ops +func.func @cast_ops(%mr: memref) -> memref { + %ptr = ptr.to_ptr %mr : memref -> !ptr.ptr<#ptr.generic_space> + %mda = ptr.get_metadata %mr : memref + %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref + %mr0 = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref + return %res : memref +}