Skip to content

Commit e31d6ce

Browse files
committed
add read-write ops
1 parent a558d65 commit e31d6ce

File tree

4 files changed

+437
-0
lines changed

4 files changed

+437
-0
lines changed

mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ include "mlir/IR/OpAsmInterface.td"
2424

2525
def AlignmentProp : OptionalProp<I64Prop>;
2626

27+
def ContiguityProp : IntArrayProp<I32Prop, "memory access contiguity information">;
28+
2729
//===----------------------------------------------------------------------===//
2830
// Common types
2931
//===----------------------------------------------------------------------===//
@@ -45,6 +47,15 @@ def Ptr_IntLikeType :AnyTypeOf<[
4547
AnySignlessIntegerOrIndex
4648
]>;
4749

50+
// A shaped pointer type with value semantics.
51+
def Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
52+
53+
// A shaped mask type with value semantics.
54+
def Ptr_ShapedMaskType : Ptr_ShapedValueType<[I1], [HasRankPred]>;
55+
56+
// A shaped mask type with value semantics.
57+
def Ptr_ShapedAnyType : Ptr_ShapedValueType<[AnyType], [HasRankPred]>;
58+
4859
// A shaped value type of rank 1 of any element type.
4960
def Ptr_Any1DType :
5061
Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
@@ -472,6 +483,127 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
472483
let hasVerifier = 1;
473484
}
474485

486+
//===----------------------------------------------------------------------===//
487+
// ReadOp
488+
//===----------------------------------------------------------------------===//
489+
490+
def Ptr_ReadOp : Pointer_Op<"read", [
491+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
492+
TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
493+
::llvm::cast<ShapedType>($_self).clone(
494+
IntegerType::get($_self.getContext(), 1))
495+
}]>,
496+
AllTypesMatch<["result", "passthrough"]>,
497+
// Check the shapes are compatible and both use the same shaped container
498+
// type.
499+
AllShapesMatch<["result", "ptr"]>, AllTypeIDsMatch<["result", "ptr"]>
500+
]> {
501+
let summary = "Read operation";
502+
let description = [{
503+
The `read` operation is a high-level operation that performs a read
504+
from multiple memory locations specified by `ptr` based on a mask `mask`.
505+
Elements of the `result`, corresponding to masked-off lanes, are taken from
506+
the `passthrough` operand.
507+
508+
The `mask` operand is a shaped type of `i1` elements that must have the same
509+
shape as the result type.
510+
511+
The `contiguity` property is an integer array with the same rank as `ptr`,
512+
where each element describes memory access contiguity for the corresponding
513+
dimension. The precise semantics of this property are given by:
514+
Let `c1, c2, ..., cn` be the elements of the contiguity array, and
515+
`s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
516+
The following rules and restrictions apply:
517+
1. `ck` must be strictly positive for all k.
518+
2. `ck` must divide `sk` for all k.
519+
3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
520+
given by:
521+
- `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
522+
`l = 0, 1, ..., sk / ck - 1`
523+
are contiguous for all k.
524+
525+
It is undefined behavior if the pointers in `ptr` do not satisfy the
526+
contiguity constraints specified by `contiguity`.
527+
528+
Depending on the values of `mask` and `contiguity`, the operation can be
529+
lowered to either:
530+
1. A `ptr.load`, if the mask is all ones, and there's a dimension where all
531+
the accesses are contiguous.
532+
2. A `ptr.masked_load`, if the mask is not all ones, and there's a dimension
533+
where all the accesses are contiguous.
534+
3. A `ptr.gather` if the mask is not all ones, and there's no contiguous
535+
dimension.
536+
537+
The alignment property describes the alignment (in bytes) of each contiguous
538+
memory-block being accessed.
539+
540+
Examples:
541+
```mlir
542+
// Read a vector in row-major order
543+
%result = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] :
544+
vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
545+
546+
// Read a vector in column-major order with alignment
547+
%result = ptr.read %ptr, %mask, %passthrough alignment = 8
548+
contiguity = [4, 1] :
549+
vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
550+
551+
// Gather a vector from memory
552+
%result = ptr.read %ptr, %mask, %passthrough alignment = 8
553+
contiguity = [1, 1] :
554+
vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
555+
```
556+
}];
557+
let arguments = (ins Ptr_ShapedPtrType:$ptr,
558+
Ptr_ShapedMaskType:$mask,
559+
Ptr_ShapedAnyType:$passthrough,
560+
AlignmentProp:$alignment,
561+
ContiguityProp:$contiguity);
562+
let results = (outs Ptr_ShapedAnyType:$result);
563+
let assemblyFormat = [{
564+
$ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
565+
`contiguity` `=` $contiguity attr-dict `:` type($ptr) `->` type($result)
566+
}];
567+
let builders = [
568+
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$passthrough,
569+
CArg<"unsigned", "0">:$alignment,
570+
CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
571+
];
572+
let hasVerifier = 1;
573+
let extraClassDeclaration = [{
574+
/// Returns the ptr type of the operation.
575+
PtrType getPtrType() {
576+
return cast<PtrType>(getPtr().getType().getElementType());
577+
}
578+
579+
/// Returns the rank of the shaped operands and result.
580+
unsigned getRank() { return getType().getRank(); }
581+
582+
/// Returns the shape of the shaped operands and result.
583+
ArrayRef<int64_t> getShape() { return getType().getShape(); }
584+
585+
/// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
586+
/// of the `i`-th dimension.
587+
std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
588+
assert(i < getRank() && "Invalid dimension");
589+
return {getContiguity()[i], getShape()[i]};
590+
}
591+
592+
/// Returns true if the `i`-th dimension is contiguous.
593+
bool isContiguous(unsigned i) {
594+
auto [contiguity, size] = getContiguityInfo(i);
595+
return contiguity == size && size > 1;
596+
}
597+
598+
/// Returns true if the read has gather semantics, ie. there's no dimension
599+
/// where all the accesses are contiguous.
600+
bool hasGatherSemantics() {
601+
return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
602+
[this](unsigned i) { return isContiguous(i); });
603+
}
604+
}];
605+
}
606+
475607
//===----------------------------------------------------------------------===//
476608
// ScatterOp
477609
//===----------------------------------------------------------------------===//
@@ -645,4 +777,120 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
645777
}];
646778
}
647779

780+
//===----------------------------------------------------------------------===//
781+
// WriteOp
782+
//===----------------------------------------------------------------------===//
783+
784+
def Ptr_WriteOp : Pointer_Op<"write", [
785+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
786+
TypesMatchWith<"value and mask must be compatible",
787+
"value", "mask", [{
788+
cast<ShapedType>($_self).clone(IntegerType::get($_self.getContext(), 1))
789+
}]>,
790+
// Check the shapes are compatible and both use the same shaped container
791+
AllShapesMatch<["value", "ptr"]>, AllTypeIDsMatch<["value", "ptr"]>
792+
]> {
793+
let summary = "Write operation";
794+
let description = [{
795+
The `write` operation is a high-level operation that performs a write to
796+
multiple memory locations specified by `ptr` based on a mask `mask`.
797+
Elements of the `value`, corresponding to masked-off lanes, are not written
798+
to memory.
799+
800+
The `mask` operand is a shaped type of `i1` elements that must have the same
801+
shape as the `value` type.
802+
803+
The `contiguity` property is an integer array with the same rank as `ptr`,
804+
where each element describes memory access contiguity for the corresponding
805+
dimension. The precise semantics of this property are given by:
806+
Let `c1, c2, ..., cn` be the elements of the contiguity array, and
807+
`s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
808+
The following rules and restrictions apply:
809+
1. `ck` must be strictly positive for all k.
810+
2. `ck` must divide `sk` for all k.
811+
3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
812+
given by:
813+
- `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
814+
`l = 0, 1, ..., sk / ck - 1`
815+
are contiguous for all k.
816+
817+
It is undefined behavior if the pointers in `ptr` do not satisfy the
818+
contiguity constraints specified by `contiguity`.
819+
820+
Depending on the values of `mask` and `contiguity`, the operation can be
821+
lowered to either:
822+
1. A `ptr.store`, if the mask is all ones, and there's a dimension where all
823+
the accesses are contiguous.
824+
2. A `ptr.masked_store`, if the mask is not all ones, and there's a dimension
825+
where all the accesses are contiguous.
826+
3. A `ptr.scatter` if the mask is not all ones, and there's no contiguous
827+
dimension.
828+
829+
The alignment property describes the alignment (in bytes) of each contiguous
830+
memory-block being accessed.
831+
832+
Example:
833+
```mlir
834+
// Write a vector in row-major order
835+
ptr.write %value, %ptr, %mask contiguity = [1, 4] :
836+
vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
837+
838+
// Write a vector in column-major order with alignment
839+
ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] :
840+
vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
841+
842+
// Scatter a vector to memory
843+
ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] :
844+
vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
845+
```
846+
}];
847+
let arguments = (ins Ptr_ShapedAnyType:$value,
848+
Ptr_ShapedPtrType:$ptr,
849+
Ptr_ShapedMaskType:$mask,
850+
AlignmentProp:$alignment,
851+
ContiguityProp:$contiguity);
852+
let assemblyFormat = [{
853+
$value `,` $ptr `,` $mask (`alignment` `=` $alignment^)?
854+
`contiguity` `=` $contiguity attr-dict `:` type($value) `,` type($ptr)
855+
}];
856+
let builders = [
857+
OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
858+
CArg<"unsigned", "0">:$alignment,
859+
CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
860+
];
861+
let hasVerifier = 1;
862+
let extraClassDeclaration = [{
863+
/// Returns the ptr type of the operation.
864+
PtrType getPtrType() {
865+
return cast<PtrType>(getPtr().getType().getElementType());
866+
}
867+
868+
/// Returns the rank of the shaped operands.
869+
unsigned getRank() { return getPtr().getType().getRank(); }
870+
871+
/// Returns the shape of the shaped operands.
872+
ArrayRef<int64_t> getShape() { return getPtr().getType().getShape(); }
873+
874+
/// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
875+
/// of the `i`-th dimension.
876+
std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
877+
assert(i < getRank() && "Invalid dimension");
878+
return {getContiguity()[i], getShape()[i]};
879+
}
880+
881+
/// Returns true if the `i`-th dimension is contiguous.
882+
bool isContiguous(unsigned i) {
883+
auto [contiguity, size] = getContiguityInfo(i);
884+
return contiguity == size && size > 1;
885+
}
886+
887+
/// Returns true if the write has scatter semantics, ie. there's no
888+
/// dimension where all the accesses are contiguous.
889+
bool hasScatterSemantics() {
890+
return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
891+
[this](unsigned i) { return isContiguous(i); });
892+
}
893+
}];
894+
}
895+
648896
#endif // PTR_OPS

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,25 @@ verifyAlignment(std::optional<int64_t> alignment,
5757
return success();
5858
}
5959

60+
/// Verifies that the contiguity array has the right size, all the elements are
61+
/// positive and divide the corresponding shape dimension.
62+
static LogicalResult
63+
verifyContiguityProp(ArrayRef<int32_t> contiguity, ArrayRef<int64_t> shape,
64+
function_ref<InFlightDiagnostic()> emitError) {
65+
if (contiguity.size() != shape.size()) {
66+
return emitError() << "expected contiguity array with " << shape.size()
67+
<< " elements";
68+
}
69+
if (!llvm::all_of(llvm::zip(contiguity, shape), [](auto cs) {
70+
int32_t c = std::get<0>(cs);
71+
return c > 0 && std::get<1>(cs) % c == 0;
72+
})) {
73+
return emitError()
74+
<< "expected contiguity values to be positive and divide the shape";
75+
}
76+
return success();
77+
}
78+
6079
//===----------------------------------------------------------------------===//
6180
// ConstantOp
6281
//===----------------------------------------------------------------------===//
@@ -264,6 +283,49 @@ void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
264283
alignment ? std::optional<int64_t>(alignment) : std::nullopt);
265284
}
266285

286+
//===----------------------------------------------------------------------===//
287+
// ReadOp
288+
//===----------------------------------------------------------------------===//
289+
290+
void ReadOp::getEffects(
291+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
292+
&effects) {
293+
effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
294+
}
295+
296+
LogicalResult ReadOp::verify() {
297+
auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
298+
299+
// Verify that the pointer type's memory space allows loads.
300+
MemorySpaceAttrInterface ms =
301+
cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
302+
DataLayout dataLayout = DataLayout::closest(*this);
303+
if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
304+
getAlignment(), &dataLayout, emitDiag))
305+
return failure();
306+
307+
// Verify the alignment.
308+
if (failed(verifyAlignment(getAlignment(), emitDiag)))
309+
return failure();
310+
311+
// Verify the contiguity array.
312+
return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
313+
}
314+
315+
void ReadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
316+
Value mask, Value passthrough, unsigned alignment,
317+
ArrayRef<int32_t> contiguity) {
318+
if (!contiguity.empty()) {
319+
build(builder, state, ptr, mask, passthrough,
320+
alignment ? std::optional<int64_t>(alignment) : std::nullopt,
321+
contiguity);
322+
return;
323+
}
324+
build(builder, state, ptr, mask, passthrough,
325+
alignment ? std::optional<int64_t>(alignment) : std::nullopt,
326+
SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
327+
}
328+
267329
//===----------------------------------------------------------------------===//
268330
// ScatterOp
269331
//===----------------------------------------------------------------------===//
@@ -470,6 +532,49 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
470532
return dl.getTypeSize(getElementType());
471533
}
472534

535+
//===----------------------------------------------------------------------===//
536+
// WriteOp
537+
//===----------------------------------------------------------------------===//
538+
539+
void WriteOp::getEffects(
540+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
541+
&effects) {
542+
effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
543+
}
544+
545+
LogicalResult WriteOp::verify() {
546+
auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
547+
548+
// Verify that the pointer type's memory space allows stores.
549+
MemorySpaceAttrInterface ms =
550+
cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
551+
DataLayout dataLayout = DataLayout::closest(*this);
552+
if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
553+
getAlignment(), &dataLayout, emitDiag))
554+
return failure();
555+
556+
// Verify the alignment.
557+
if (failed(verifyAlignment(getAlignment(), emitDiag)))
558+
return failure();
559+
560+
// Verify the contiguity array.
561+
return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
562+
}
563+
564+
void WriteOp::build(OpBuilder &builder, OperationState &state, Value value,
565+
Value ptr, Value mask, unsigned alignment,
566+
ArrayRef<int32_t> contiguity) {
567+
if (!contiguity.empty()) {
568+
build(builder, state, value, ptr, mask,
569+
alignment ? std::optional<int64_t>(alignment) : std::nullopt,
570+
contiguity);
571+
return;
572+
}
573+
build(builder, state, value, ptr, mask,
574+
alignment ? std::optional<int64_t>(alignment) : std::nullopt,
575+
SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
576+
}
577+
473578
//===----------------------------------------------------------------------===//
474579
// Pointer API.
475580
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)