Skip to content

Commit fe22487

Browse files
committed
[mlir][core|ptr] Add PtrLikeTypeInterface and casting ops to the ptr dialect
This patch adds the `PtrLikeTypeInterface` type interface to identify pointer-like types. This interface is defined as: ``` A ptr-like type represents an object storing a memory address. This object is constituted by: - A memory address called the base pointer. The base pointer is an indivisible object. - Optional metadata about the pointer. For example, the size of the memory region associated with the pointer. Furthermore, all ptr-like types have two properties: - The memory space associated with the address held by the pointer. - An optional element type. If the element type is not specified, the pointer is considered opaque. ``` This patch adds this interface to `!ptr.ptr` and the `memref` type. Furthermore, this patch adds necessary ops and type to handle casting between `!ptr.ptr` and ptr-like types. First, it defines the `!ptr.ptr_metadata` type. An opaque type to represent the metadata of a ptr-like type. The rationale behind adding this type, is that at high-level the metadata of a type like `memref` cannot be specified, as its structure is tied to its lowering. The `ptr.get_metadata` operation was added to extract the opaque pointer metadata. The concrete structure of the metadata is only known when the op is lowered. Finally, this patch adds the `ptr.from_ptr` and `ptr.to_ptr` operations. Allowing to cast back and forth between `!ptr.ptr` and ptr-liker types. ```mlir func.func @func(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> { %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space> %mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space> %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space> return %res : memref<f32, #ptr.generic_space> } ```
1 parent 492d25b commit fe22487

File tree

11 files changed

+418
-1
lines changed

11 files changed

+418
-1
lines changed

mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
3737

3838
def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
3939
MemRefElementTypeInterface,
40+
PtrLikeTypeInterface,
4041
VectorElementTypeInterface,
4142
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
4243
"areCompatible", "getIndexBitwidth", "verifyEntries",
@@ -63,6 +64,54 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
6364
return $_get(memorySpace.getContext(), memorySpace);
6465
}]>
6566
];
67+
let extraClassDeclaration = [{
68+
// `PtrLikeTypeInterface` interface methods.
69+
/// Returns `Type()` as this pointer type is opaque.
70+
Type getElementType() const {
71+
return Type();
72+
}
73+
/// Clones the pointer with specified memory space or returns failure
74+
/// if an `elementType` was specified or if the memory space doesn't
75+
/// implement `MemorySpaceAttrInterface`.
76+
FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
77+
std::optional<Type> elementType) const {
78+
if (elementType)
79+
return failure();
80+
if (auto ms = dyn_cast<MemorySpaceAttrInterface>(memorySpace))
81+
return cast<PtrLikeTypeInterface>(get(ms));
82+
return failure();
83+
}
84+
/// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
85+
bool hasPtrMetadata() const {
86+
return false;
87+
}
88+
}];
89+
}
90+
91+
def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
92+
let summary = "Pointer metadata type";
93+
let description = [{
94+
The `ptr_metadata` type represents an opaque-view of the metadata associated
95+
with a `ptr-like` object type.
96+
It's an error to get a `ptr_metadata` using `ptr-like` type with no
97+
metadata.
98+
99+
Example:
100+
101+
```mlir
102+
// The metadata associated with a `memref` type.
103+
!ptr.ptr_metadata<memref<f32>>
104+
```
105+
}];
106+
let parameters = (ins "PtrLikeTypeInterface":$type);
107+
let assemblyFormat = "`<` $type `>`";
108+
let builders = [
109+
TypeBuilderWithInferredContext<(ins
110+
"PtrLikeTypeInterface":$ptrLike), [{
111+
return $_get(ptrLike.getContext(), ptrLike);
112+
}]>
113+
];
114+
let genVerifyDecl = 1;
66115
}
67116

68117
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,75 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/Interfaces/ViewLikeInterface.td"
1818
include "mlir/IR/OpAsmInterface.td"
1919

20+
//===----------------------------------------------------------------------===//
21+
// FromPtrOp
22+
//===----------------------------------------------------------------------===//
23+
24+
def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
25+
Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
26+
"PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
27+
]> {
28+
let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
29+
let description = [{
30+
The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
31+
important to note that:
32+
- The ptr-like object cannot be a `!ptr.ptr`.
33+
- The memory-space of both the `ptr` and ptr-like object must match.
34+
- The cast is side-effect free.
35+
36+
If the ptr-like object type has metadata, then the operation expects the
37+
metadata as an argument or expects that the flag `trivial_metadata` is set.
38+
If `trivial_metadata` is set, then it is assumed that the metadata can be
39+
reconstructed statically from the pointer-like type.
40+
41+
Example:
42+
43+
```mlir
44+
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<0> -> !my.ptr<f32, 0>
45+
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<0> -> memref<f32, 0>
46+
%memref = ptr.from_ptr %ptr trivial_metadata : !ptr.ptr<0> -> memref<f32, 0>
47+
```
48+
}];
49+
50+
let arguments = (ins Ptr_PtrType:$ptr,
51+
Optional<Ptr_PtrMetadata>:$metadata,
52+
UnitProp:$hasTrivialMetadata);
53+
let results = (outs PtrLikeTypeInterface:$result);
54+
let assemblyFormat = [{
55+
$ptr (`metadata` $metadata^)? (`trivial_metadata` $hasTrivialMetadata^)?
56+
attr-dict `:` type($ptr) `->` type($result)
57+
}];
58+
let hasFolder = 1;
59+
let hasVerifier = 1;
60+
}
61+
62+
//===----------------------------------------------------------------------===//
63+
// GetMetadataOp
64+
//===----------------------------------------------------------------------===//
65+
66+
def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
67+
Pure, TypesMatchWith<"metadata type", "ptr", "result",
68+
"PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
69+
]> {
70+
let summary = "SSA value representing pointer metadata.";
71+
let description = [{
72+
The `get_metadata` operation produces an opaque value that encodes the
73+
metadata of the ptr-like type.
74+
75+
Example:
76+
77+
```mlir
78+
%metadata = ptr.get_metadata %memref : memref<?x?xf32>
79+
```
80+
}];
81+
82+
let arguments = (ins PtrLikeTypeInterface:$ptr);
83+
let results = (outs Ptr_PtrMetadata:$result);
84+
let assemblyFormat = [{
85+
$ptr attr-dict `:` type($ptr)
86+
}];
87+
}
88+
2089
//===----------------------------------------------------------------------===//
2190
// PtrAddOp
2291
//===----------------------------------------------------------------------===//
@@ -52,6 +121,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
52121
}];
53122
}
54123

124+
//===----------------------------------------------------------------------===//
125+
// ToPtrOp
126+
//===----------------------------------------------------------------------===//
127+
128+
def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
129+
let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
130+
let description = [{
131+
The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
132+
important to note that:
133+
- The ptr-like object cannot be a `!ptr.ptr`.
134+
- The memory-space of both the `ptr` and ptr-like object must match.
135+
- The cast is side-effect free.
136+
137+
Example:
138+
139+
```mlir
140+
%ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, 0> -> !ptr.ptr<0>
141+
%ptr1 = ptr.to_ptr %memref : memref<f32, 0> -> !ptr.ptr<0>
142+
```
143+
}];
144+
145+
let arguments = (ins PtrLikeTypeInterface:$ptr);
146+
let results = (outs Ptr_PtrType:$result);
147+
let assemblyFormat = [{
148+
$ptr attr-dict `:` type($ptr) `->` type($result)
149+
}];
150+
let hasFolder = 1;
151+
let hasVerifier = 1;
152+
}
153+
55154
//===----------------------------------------------------------------------===//
56155
// TypeOffsetOp
57156
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,55 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
110110
}];
111111
}
112112

113+
//===----------------------------------------------------------------------===//
114+
// PtrLikeTypeInterface
115+
//===----------------------------------------------------------------------===//
116+
117+
def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
118+
let cppNamespace = "::mlir";
119+
let description = [{
120+
A ptr-like type represents an object storing a memory address. This object
121+
is constituted by:
122+
- A memory address called the base pointer. The base pointer is an
123+
indivisible object.
124+
- Optional metadata about the pointer. For example, the size of the memory
125+
region associated with the pointer.
126+
127+
Furthermore, all ptr-like types have two properties:
128+
- The memory space associated with the address held by the pointer.
129+
- An optional element type. If the element type is not specified, the
130+
pointer is considered opaque.
131+
}];
132+
let methods = [
133+
InterfaceMethod<[{
134+
Returns the memory space of this ptr-like type.
135+
}],
136+
"::mlir::Attribute", "getMemorySpace">,
137+
InterfaceMethod<[{
138+
Returns the element type of this ptr-like type. Note: this method can
139+
return `::mlir::Type()`, in which case the pointer is considered opaque.
140+
}],
141+
"::mlir::Type", "getElementType">,
142+
InterfaceMethod<[{
143+
Returns whether this ptr-like type has non-empty metadata.
144+
}],
145+
"bool", "hasPtrMetadata">,
146+
InterfaceMethod<[{
147+
Returns a clone of this type with the given memory space and element type,
148+
or `failure` if the type cannot be cloned with the specified arguments.
149+
If the pointer is opaque and `elementType` is not `std::nullopt` the
150+
method will return `failure`.
151+
152+
If no `elementType` is provided and ptr is not opaque, the `elementType`
153+
of this type is used.
154+
}],
155+
"::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
156+
"::mlir::Attribute":$memorySpace,
157+
"::std::optional<::mlir::Type>":$elementType
158+
)>
159+
];
160+
}
161+
113162
//===----------------------------------------------------------------------===//
114163
// ShapedType
115164
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
9999
/// Note: This class attaches the ShapedType trait to act as a mixin to
100100
/// provide many useful utility functions. This inheritance has no effect
101101
/// on derived memref types.
102-
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
102+
class BaseMemRefType : public Type,
103+
public PtrLikeTypeInterface::Trait<BaseMemRefType>,
104+
public ShapedType::Trait<BaseMemRefType> {
103105
public:
104106
using Type::Type;
105107

@@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
117119
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
118120
Type elementType) const;
119121

122+
/// Clone this type with the given memory space and element type. If the
123+
/// provided element type is `std::nullopt`, the current element type of the
124+
/// type is used.
125+
FailureOr<PtrLikeTypeInterface>
126+
clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;
127+
120128
// Make sure that base class overloads are visible.
121129
using ShapedType::Trait<BaseMemRefType>::clone;
122130

@@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
141149
/// New `Attribute getMemorySpace()` method should be used instead.
142150
unsigned getMemorySpaceAsInt() const;
143151

152+
/// Returns that this ptr-like object has non-empty ptr metadata.
153+
bool hasPtrMetadata() const { return true; }
154+
144155
/// Allow implicit conversion to ShapedType.
145156
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
157+
158+
/// Allow implicit conversion to PtrLikeTypeInterface.
159+
operator PtrLikeTypeInterface() const {
160+
return llvm::cast<PtrLikeTypeInterface>(*this);
161+
}
146162
};
147163

148164
} // namespace mlir

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
562562
//===----------------------------------------------------------------------===//
563563

564564
def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
565+
PtrLikeTypeInterface,
565566
ShapedTypeInterface
566567
], "BaseMemRefType"> {
567568
let summary = "Shaped reference to a region of memory";
@@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
11431144
//===----------------------------------------------------------------------===//
11441145

11451146
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
1147+
PtrLikeTypeInterface,
11461148
ShapedTypeInterface
11471149
], "BaseMemRefType"> {
11481150
let summary = "Shaped reference, with unknown rank, to a region of memory";

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,54 @@ void PtrDialect::initialize() {
4141
>();
4242
}
4343

44+
//===----------------------------------------------------------------------===//
45+
// FromPtrOp
46+
//===----------------------------------------------------------------------===//
47+
48+
OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
49+
// Fold the pattern:
50+
// %ptr = ptr.to_ptr %v : type -> ptr
51+
// (%mda = ptr.get_metadata %v : type)?
52+
// %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
53+
// To:
54+
// %val -> %v
55+
auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr().getDefiningOp());
56+
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
57+
// different.
58+
if (!toPtr || toPtr.getPtr().getType() != getType())
59+
return nullptr;
60+
Value md = getMetadata();
61+
if (!md)
62+
return toPtr.getPtr();
63+
// Fold if the metadata can be verified to be equal.
64+
if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
65+
mdOp && mdOp.getPtr() == toPtr.getPtr())
66+
return toPtr.getPtr();
67+
return nullptr;
68+
}
69+
70+
LogicalResult FromPtrOp::verify() {
71+
if (isa<PtrType>(getType()))
72+
return emitError() << "the result type cannot be `!ptr.ptr`";
73+
if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
74+
return emitError()
75+
<< "expected the input and output to have the same memory space";
76+
}
77+
bool hasMD = getMetadata() != Value();
78+
bool hasTrivialMD = getHasTrivialMetadata();
79+
if (hasMD && hasTrivialMD) {
80+
return emitError() << "expected either a metadata argument or the "
81+
"`trivial_metadata` flag, not both";
82+
}
83+
if (getType().hasPtrMetadata() && !(hasMD || hasTrivialMD)) {
84+
return emitError() << "expected either a metadata argument or the "
85+
"`trivial_metadata` flag to be set";
86+
}
87+
if (!getType().hasPtrMetadata() && (hasMD || hasTrivialMD))
88+
return emitError() << "expected no metadata specification";
89+
return success();
90+
}
91+
4492
//===----------------------------------------------------------------------===//
4593
// PtrAddOp
4694
//===----------------------------------------------------------------------===//
@@ -55,6 +103,33 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
55103
return nullptr;
56104
}
57105

106+
//===----------------------------------------------------------------------===//
107+
// ToPtrOp
108+
//===----------------------------------------------------------------------===//
109+
110+
OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
111+
// Fold the pattern:
112+
// %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
113+
// %ptr = ptr.to_ptr %val : type -> ptr
114+
// To:
115+
// %ptr -> %p
116+
auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr().getDefiningOp());
117+
// Cannot fold if it's not a `from_ptr` op.
118+
if (!fromPtr)
119+
return nullptr;
120+
return fromPtr.getPtr();
121+
}
122+
123+
LogicalResult ToPtrOp::verify() {
124+
if (isa<PtrType>(getPtr().getType()))
125+
return emitError() << "the input value cannot be of type `!ptr.ptr`";
126+
if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
127+
return emitError()
128+
<< "expected the input and output to have the same memory space";
129+
}
130+
return success();
131+
}
132+
58133
//===----------------------------------------------------------------------===//
59134
// TypeOffsetOp
60135
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,15 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
151151
}
152152
return success();
153153
}
154+
155+
//===----------------------------------------------------------------------===//
156+
// Pointer metadata
157+
//===----------------------------------------------------------------------===//
158+
159+
LogicalResult
160+
PtrMetadataType::verify(function_ref<InFlightDiagnostic()> emitError,
161+
PtrLikeTypeInterface type) {
162+
if (!type.hasPtrMetadata())
163+
return emitError() << "the ptr-like type has no metadata";
164+
return success();
165+
}

0 commit comments

Comments
 (0)