Skip to content

Commit 1a65e63

Browse files
fabianmcgjoker-eph
andauthored
[mlir][ptr] Add ConstantOp with NullAttr and AddressAttr support (#157347)
This patch introduces the `ptr.constant` operation. It also adds the `NullAttr` and `AddressAttr` for representing null pointers, and integer raw addresses. It also implements LLVM IR translation for `ptr.constant` with `#ptr.null` or `#ptr.address` attributes. Finally, it extends `FieldParser` to support APInt parsing. Example: ```mlir llvm.func @constant_address_op() -> !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> { %0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<0>> %1 = ptr.constant #ptr.address<0x1000> : !ptr.ptr<#llvm.address_space<1>> %2 = ptr.constant #ptr.address<3735928559> : !ptr.ptr<#llvm.address_space<2>> %3 = llvm.mlir.poison : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> %4 = llvm.insertvalue %0, %3[0] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> %5 = llvm.insertvalue %1, %4[1] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> %6 = llvm.insertvalue %2, %5[2] : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> llvm.return %6 : !llvm.struct<(!ptr.ptr<#llvm.address_space<0>>, !ptr.ptr<#llvm.address_space<1>>, !ptr.ptr<#llvm.address_space<2>>)> } ``` Result of translation to LLVM IR: ```llvm define { ptr, ptr addrspace(1), ptr addrspace(2) } @constant_address_op() { ret { ptr, ptr addrspace(1), ptr addrspace(2) } { ptr null, ptr addrspace(1) inttoptr (i64 4096 to ptr addrspace(1)), ptr addrspace(2) inttoptr (i64 3735928559 to ptr addrspace(2)) } } ``` This patch also changes all the `convert*` occurrences in function names or comments to `translate` in the PtrToLLVM file. --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent bfedb4a commit 1a65e63

File tree

8 files changed

+262
-66
lines changed

8 files changed

+262
-66
lines changed

mlir/include/mlir/Dialect/Ptr/IR/PtrAttrDefs.td

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,34 @@ class Ptr_Attr<string name, string attrMnemonic,
2222
let mnemonic = attrMnemonic;
2323
}
2424

25+
//===----------------------------------------------------------------------===//
26+
// AddressAttr
27+
//===----------------------------------------------------------------------===//
28+
29+
def Ptr_AddressAttr : Ptr_Attr<"Address", "address", [
30+
DeclareAttrInterfaceMethods<TypedAttrInterface>
31+
]> {
32+
let summary = "Address attribute";
33+
let description = [{
34+
The `address` attribute represents a raw memory address, expressed in bytes.
35+
36+
Example:
37+
38+
```mlir
39+
#ptr.address<0x1000> : !ptr.ptr<#ptr.generic_space>
40+
```
41+
}];
42+
let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type,
43+
APIntParameter<"">:$value);
44+
let builders = [
45+
AttrBuilderWithInferredContext<(ins "PtrType":$type,
46+
"const llvm::APInt &":$value), [{
47+
return $_get(type.getContext(), type, value);
48+
}]>
49+
];
50+
let assemblyFormat = "`<` $value `>`";
51+
}
52+
2553
//===----------------------------------------------------------------------===//
2654
// GenericSpaceAttr
2755
//===----------------------------------------------------------------------===//
@@ -37,16 +65,42 @@ def Ptr_GenericSpaceAttr :
3765
- Load and store operations are always valid, regardless of the type.
3866
- Atomic operations are always valid, regardless of the type.
3967
- Cast operations to `generic_space` are always valid.
40-
68+
4169
Example:
4270

4371
```mlir
44-
#ptr.generic_space
72+
#ptr.generic_space : !ptr.ptr<#ptr.generic_space>
4573
```
4674
}];
4775
let assemblyFormat = "";
4876
}
4977

78+
//===----------------------------------------------------------------------===//
79+
// NullAttr
80+
//===----------------------------------------------------------------------===//
81+
82+
def Ptr_NullAttr : Ptr_Attr<"Null", "null", [
83+
DeclareAttrInterfaceMethods<TypedAttrInterface>
84+
]> {
85+
let summary = "Null pointer attribute";
86+
let description = [{
87+
The `null` attribute represents a null pointer.
88+
89+
Example:
90+
91+
```mlir
92+
#ptr.null
93+
```
94+
}];
95+
let parameters = (ins AttributeSelfTypeParameter<"", "PtrType">:$type);
96+
let builders = [
97+
AttrBuilderWithInferredContext<(ins "PtrType":$type), [{
98+
return $_get(type.getContext(), type);
99+
}]>
100+
];
101+
let assemblyFormat = "";
102+
}
103+
50104
//===----------------------------------------------------------------------===//
51105
// SpecAttr
52106
//===----------------------------------------------------------------------===//
@@ -62,7 +116,7 @@ def Ptr_SpecAttr : Ptr_Attr<"Spec", "spec"> {
62116
- [Optional] index: bitwidth that should be used when performing index
63117
computations for the type. Setting the field to `kOptionalSpecValue`, means
64118
the field is optional.
65-
119+
66120
Furthermore, the attribute will verify that all present values are divisible
67121
by 8 (number of bits in a byte), and that `preferred` > `abi`.
68122

mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
2222
#include "mlir/Dialect/Ptr/IR/PtrEnums.h"
2323

24+
namespace mlir {
25+
namespace ptr {
26+
class PtrType;
27+
} // namespace ptr
28+
} // namespace mlir
29+
2430
#define GET_ATTRDEF_CLASSES
2531
#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.h.inc"
2632

mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
3636
/*cppType=*/"::mlir::ShapedType">;
3737

3838
// A ptr-like type, either scalar or shaped type with value semantics.
39-
def Ptr_PtrLikeType :
39+
def Ptr_PtrLikeType :
4040
AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
4141

4242
// An int-like type, either scalar or shaped type with value semantics.
@@ -57,6 +57,31 @@ def Ptr_Mask1DType :
5757
def Ptr_Ptr1DType :
5858
Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
5959

60+
//===----------------------------------------------------------------------===//
61+
// ConstantOp
62+
//===----------------------------------------------------------------------===//
63+
64+
def Ptr_ConstantOp : Pointer_Op<"constant", [
65+
ConstantLike, Pure, AllTypesMatch<["value", "result"]>
66+
]> {
67+
let summary = "Pointer constant operation";
68+
let description = [{
69+
The `constant` operation produces a pointer constant. The attribute must be
70+
a typed attribute of pointer type.
71+
72+
Example:
73+
74+
```mlir
75+
// Create a null pointer
76+
%null = ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
77+
```
78+
}];
79+
let arguments = (ins TypedAttrInterface:$value);
80+
let results = (outs Ptr_PtrType:$result);
81+
let assemblyFormat = "attr-dict $value";
82+
let hasFolder = 1;
83+
}
84+
6085
//===----------------------------------------------------------------------===//
6186
// FromPtrOp
6287
//===----------------------------------------------------------------------===//
@@ -81,7 +106,7 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
81106
```mlir
82107
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
83108
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
84-
109+
85110
// Cast the `%ptr` to a memref without utilizing metadata.
86111
%memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
87112
```
@@ -361,13 +386,13 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
361386
// Scalar base and offset
362387
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
363388
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
364-
389+
365390
// Shaped base with scalar offset
366391
%ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
367-
392+
368393
// Scalar base with shaped offset
369394
%x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
370-
395+
371396
// Both base and offset are shaped
372397
%ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
373398
```
@@ -382,7 +407,7 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
382407
}];
383408
let hasFolder = 1;
384409
let extraClassDeclaration = [{
385-
/// `ViewLikeOp::getViewSource` method.
410+
/// `ViewLikeOp::getViewSource` method.
386411
Value getViewSource() { return getBase(); }
387412

388413
/// Returns the ptr type of the operation.
@@ -418,7 +443,7 @@ def Ptr_ScatterOp : Pointer_Op<"scatter", [
418443
// Scatter values to multiple memory locations
419444
ptr.scatter %value, %ptrs, %mask :
420445
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
421-
446+
422447
// Scatter with alignment
423448
ptr.scatter %value, %ptrs, %mask alignment = 8 :
424449
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>

mlir/include/mlir/IR/DialectImplementation.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,11 @@ struct FieldParser<
103103

104104
/// Parse any integer.
105105
template <typename IntT>
106-
struct FieldParser<IntT,
107-
std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
106+
struct FieldParser<IntT, std::enable_if_t<(std::is_integral<IntT>::value ||
107+
std::is_same_v<IntT, llvm::APInt>),
108+
IntT>> {
108109
static FailureOr<IntT> parse(AsmParser &parser) {
109-
IntT value = 0;
110+
IntT value{};
110111
if (parser.parseInteger(value))
111112
return failure();
112113
return value;

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ verifyAlignment(std::optional<int64_t> alignment,
5656
return success();
5757
}
5858

59+
//===----------------------------------------------------------------------===//
60+
// ConstantOp
61+
//===----------------------------------------------------------------------===//
62+
63+
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
64+
5965
//===----------------------------------------------------------------------===//
6066
// FromPtrOp
6167
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)